diff --git a/Cargo.lock b/Cargo.lock index 9fa1831..7790c45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "anyhow" +version = "1.0.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" + [[package]] name = "arrayvec" version = "0.7.6" @@ -56,6 +62,12 @@ dependencies = [ "rustc_version", ] +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.22.1" @@ -146,6 +158,12 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "const-default" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b396d1f76d455557e1218ec8066ae14bba60b4b36ecd55577ba979f5db7ecaa" + [[package]] name = "cortex-m" version = "0.7.7" @@ -605,6 +623,18 @@ dependencies = [ "embedded-io-async", ] +[[package]] +name = "embedded-alloc" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f2de9133f68db0d4627ad69db767726c99ff8585272716708227008d3f1bddd" +dependencies = [ + "const-default", + "critical-section", + "linked_list_allocator", + "rlsf", +] + [[package]] name = "embedded-hal" version = "0.2.7" @@ -934,6 +964,12 @@ version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +[[package]] +name = "linked_list_allocator" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afa463f5405ee81cdb9cc2baf37e08ec7e4c8209442b5d72c04cfb2cd6e6286" + [[package]] name = "litrs" version = "0.4.2" @@ -1113,7 +1149,8 @@ checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315" name = "pico-website" version = "0.1.0" dependencies = [ - "base64", + "anyhow", + "base64 0.22.1", "cortex-m", "cortex-m-rt", "cyw43", @@ -1127,6 +1164,7 @@ dependencies = [ "embassy-rp", "embassy-sync 0.7.1", "embassy-time", + "embedded-alloc", "embedded-io-async", "heapless", "log", @@ -1355,6 +1393,18 @@ dependencies = [ "portable-atomic-util", ] +[[package]] +name = "rlsf" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "222fb240c3286247ecdee6fa5341e7cdad0ffdf8e7e401d9937f2d58482a20bf" +dependencies = [ + "cfg-if", + "const-default", + "libc", + "svgbobdoc", +] + [[package]] name = "rp-pac" version = "7.0.0" @@ -1560,6 +1610,19 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "svgbobdoc" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2c04b93fc15d79b39c63218f15e3fdffaa4c227830686e3b7c5f41244eb3e50" +dependencies = [ + "base64 0.13.1", + "proc-macro2", + "quote", + "syn 1.0.109", + "unicode-width", +] + [[package]] name = "syn" version = "1.0.109" diff --git a/Cargo.toml b/Cargo.toml index d1bc2fc..415ee26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,4 +68,7 @@ ringbuf = { version = "*", default-features = false, features = [ ], optional = true } percent-encoding = { version = "*", default-features = false } sha1 = { version = "*", default-features = false } -base64 = { version = "*", default-features = false } \ No newline at end of file +base64 = { version = "*", default-features = false } + +anyhow = { version = "*", default-features = false } +embedded-alloc = "*" \ No newline at end of file diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 49011df..48d6318 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,3 @@ [toolchain] -channel = "nightly-2025-03-18" \ No newline at end of file +# channel = "nightly-2025-03-18" +channel = "nightly" \ No newline at end of file diff --git a/src/apps/chat.js b/src/apps/chat.js index 298f6bf..b88b587 100644 --- a/src/apps/chat.js +++ b/src/apps/chat.js @@ -41,7 +41,8 @@ ws.onmessage = (event) => { document.getElementById("send").onsubmit = (event) => { event.preventDefault(); // console.log(event, document.getElementById("sendcontent").value); + let timestamp = Date.now(); let content = document.getElementById("sendcontent"); - ws.send("send " + content.value); + ws.send("send " + timestamp.toString() + " " + content.value); content.value = ""; }; diff --git a/src/apps/chat.rs b/src/apps/chat.rs index 837c899..655428f 100644 --- a/src/apps/chat.rs +++ b/src/apps/chat.rs @@ -1,10 +1,11 @@ -use core::sync::atomic::Ordering; +use core::{str::from_utf8_unchecked, sync::atomic::Ordering}; use embassy_sync::{blocking_mutex::raw::ThreadModeRawMutex, mutex::Mutex}; use embassy_time::{Duration, Timer}; use heapless::String; // use log::{info, warn}; // use pico_website::unimplemented; +use anyhow::{Result, anyhow}; use defmt::*; use portable_atomic::AtomicUsize; use serde::Serialize; @@ -19,6 +20,7 @@ const MSG_MAX_SIZE: usize = 500; struct Msg<'a> { id: usize, author: u8, + timestamp: u64, content: &'a str, } @@ -26,8 +28,8 @@ const MSGS_SIZE: usize = 100000; const _: () = core::assert!(MSGS_SIZE > MSG_MAX_SIZE); #[derive(Debug)] struct Msgs { - /// * Memory layout with sizes in bytes : ...|content: len|len: 2|author+1: 1|... - /// * `author=0` means theres no message, it's just padding and should be skipped. + /// * Memory layout with sizes in bytes : ...|content: len|len: 2|delimiter(255)| other_msg/pad(0)... + /// * content = author: 1 | timestamp: 8 | content: len-5 /// * No message is splitted inner: [u8; MSGS_SIZE], /// next byte index @@ -42,20 +44,24 @@ impl Msgs { next_msg: 0, } } - fn push(&mut self, author: u8, content: &str) { - if self.head + content.len() + 3 >= MSGS_SIZE { + fn push(&mut self, author: u8, timestamp: u64, content: &str) { + let len = 1 + 8 + content.len(); + if self.head + len + 3 >= MSGS_SIZE { self.inner[self.head..].fill(0); self.head = 0 } - self.inner[self.head..self.head + content.len()].copy_from_slice(content.as_bytes()); - self.head += content.len(); - self.inner[self.head..self.head + 2].copy_from_slice(&(content.len() as u16).to_le_bytes()); - self.inner[self.head + 2] = author + 1; - self.head += 3; + self.inner[self.head] = author; + self.inner[self.head + 1..self.head + 9].copy_from_slice(×tamp.to_le_bytes()); + self.inner[self.head + 9..self.head + 9 + content.len()] + .copy_from_slice(content.as_bytes()); + self.inner[self.head + len..self.head + len + 2] + .copy_from_slice(&(len as u16).to_le_bytes()); + self.inner[self.head + len + 2] = u8::MAX; + self.head += len + 3; self.next_msg += 1; } /// Iter messages from present to past - fn iter(&self) -> MsgsIter { + fn iter<'a>(&'a self) -> MsgsIter<'a> { if self.head == 0 { MsgsIter { msgs: self, @@ -93,41 +99,86 @@ impl<'a> Iterator for MsgsIter<'a> { self.head = MSGS_SIZE; } let above = self.head > self.msgs.head; - if above && self.head < self.msgs.head + 3 { + if above && self.head <= self.msgs.head + 12 { self.finished = true; return None; } - let author = self.msgs.inner[self.head - 1]; - self.head -= 1; - if author == 0 { - return self.next(); + match self.msgs.inner[self.head - 1] { + // Skip padding + 0 => { + self.head -= 1; + self.next() + } + u8::MAX => { + let len = u16::from_le_bytes([ + self.msgs.inner[self.head - 3], + self.msgs.inner[self.head - 2], + ]) as usize; + self.head -= 3 + len; + let author = self.msgs.inner[self.head]; + let timestamp = u64::from_le_bytes([ + self.msgs.inner[self.head + 1], + self.msgs.inner[self.head + 2], + self.msgs.inner[self.head + 3], + self.msgs.inner[self.head + 4], + self.msgs.inner[self.head + 5], + self.msgs.inner[self.head + 6], + self.msgs.inner[self.head + 7], + self.msgs.inner[self.head + 8], + ]); + let content = unsafe { + from_utf8_unchecked(&self.msgs.inner[self.head + 9..self.head + len]) + }; + let id = self.current_id; + if self.current_id == 0 { + self.finished = true; + } else { + self.current_id -= 1; + } + Some(Msg { + id, + author, + timestamp, + content, + }) + } + _ => core::unreachable!(), } - let author = author - 1; - let len = u16::from_le_bytes([ - self.msgs.inner[self.head - 2], - self.msgs.inner[self.head - 1], - ]) as usize; - self.head -= 2; + // if self.msgs.inner[self.head - 1] == 0 { + // return self.next(); + // } - let content = - unsafe { str::from_utf8_unchecked(&self.msgs.inner[self.head - len..self.head]) }; - self.head -= len; - if above && self.head < self.msgs.head { - self.finished = true; - return None; - } - let id = self.current_id; - if self.current_id == 0 { - self.finished = true; - } else { - self.current_id -= 1; - } + // let author = self.msgs.inner[self.head - 1]; + // self.head -= 1; + // if author == 0 { + // return self.next(); + // } + // let author = author - 1; + // let len = u16::from_le_bytes([ + // self.msgs.inner[self.head - 2], + // self.msgs.inner[self.head - 1], + // ]) as usize; + // self.head -= 2; - Some(Msg { - id, - author, - content, - }) + // let content = + // unsafe { str::from_utf8_unchecked(&self.msgs.inner[self.head - len..self.head]) }; + // self.head -= len; + // if above && self.head < self.msgs.head { + // self.finished = true; + // return None; + // } + // let id = self.current_id; + // if self.current_id == 0 { + // self.finished = true; + // } else { + // self.current_id -= 1; + // } + + // Some(Msg { + // id, + // author, + // content, + // }) } } @@ -191,50 +242,68 @@ impl App for ChatApp { self.usernames_version = 0; self.next_msg = 0; Timer::after_millis(500).await; - let r: Result<(), ()> = try { - loop { - Timer::after_millis(1).await; - { - let uv = USERNAMES_VERSION.load(Ordering::Relaxed); - if self.usernames_version < uv { - ws.send_json(&(*USERNAMES.lock().await)).await?; - } - self.usernames_version = uv; - } - { - let msgs = MSGS.lock().await; - for m in msgs.iter() { - if m.id >= self.next_msg { - ws.send_json(&m).await?; + let r: Result<()> = try { + 'ws: { + loop { + Timer::after_millis(1).await; + { + let uv = USERNAMES_VERSION.load(Ordering::Relaxed); + if self.usernames_version < uv { + ws.send_json(&(*USERNAMES.lock().await)).await?; } + self.usernames_version = uv; } - self.next_msg = msgs.next_msg; - } - if ws.last_msg.elapsed() >= Duration::from_secs(5) { - ws.send(WsMsg::Ping(&[])).await?; - } - while let Some(r) = ws.rcv().await? { - info!("{:?}", r); - if let WsMsg::Text(r) = r { - if r.starts_with("send ") { - if r.len() > 5 + MSG_MAX_SIZE { - warn!("Message too long! (len={})", r.len() - 5); - return; + { + let msgs = MSGS.lock().await; + for m in msgs.iter() { + if m.id >= self.next_msg { + ws.send_json(&m).await?; } - { - MSGS.lock() - .await - .push(self.id, r.get(5..).unwrap_or_default()); + } + self.next_msg = msgs.next_msg; + } + if ws.last_msg.elapsed() >= Duration::from_secs(5) { + ws.send(WsMsg::Ping(&[])).await?; + } + while let Some(r) = ws.rcv().await? { + info!("{:?}", r); + if let WsMsg::Text(r) = r { + match r.split_once(' ') { + Some(("send", m)) => { + let (t, msg) = m + .split_once(' ') + .ok_or(anyhow!("message cannot be splitted with ' '"))?; + info!("{}", t); + let timestamp = t.parse::()?; + if msg.len() > MSG_MAX_SIZE { + warn!("Message too long! (len={})", r.len() - 5); + break 'ws; + } + MSGS.lock().await.push(self.id, timestamp, msg); + } + _ => {} } + // if r.starts_with("send ") { + // if r.len() > 5 + MSG_MAX_SIZE { + // warn!("Message too long! (len={})", r.len() - 5); + // break 'ws; + // } + // { + // MSGS.lock() + // .await + // .push(self.id, r.get(5..).unwrap_or_default()); + // } + // } } } } } }; - if r.is_err() { + if let Err(e) = r { warn!( - "Socket {}: error in ws, terminating connection", - self.socket_name() + "Socket {}: error in ws\n{}", + self.socket_name(), + Display2Format(&e) ); } } diff --git a/src/apps/ttt.rs b/src/apps/ttt.rs index a7a4ca5..a18d372 100644 --- a/src/apps/ttt.rs +++ b/src/apps/ttt.rs @@ -1,3 +1,4 @@ +use anyhow::Result; use core::ops::Not; use core::str::from_utf8_unchecked; use embassy_sync::blocking_mutex::raw::ThreadModeRawMutex; @@ -92,7 +93,7 @@ impl App for TttApp { } async fn handle_ws(&mut self, _path: &str, mut ws: Ws<'_, BUF_SIZE>) { Timer::after_millis(500).await; - let r: Result<(), ()> = try { + let r: Result<()> = try { loop { Timer::after_millis(1).await; let Ok(mut game) = GAME.try_lock() else { @@ -140,10 +141,11 @@ impl App for TttApp { } } }; - if r.is_err() { + if let Err(e) = r { warn!( - "Socket {}: error in ws, terminating connection", - self.socket_name() + "Socket {}: error in ws\n{}", + self.socket_name(), + Display2Format(&e) ); } } diff --git a/src/main.rs b/src/main.rs index 04d12ba..37218c6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,9 +6,9 @@ #![feature(try_blocks)] #![feature(impl_trait_in_bindings)] #![feature(array_repeat)] -#![feature(generic_arg_infer)] #![feature(async_iterator)] +use core::mem::MaybeUninit; #[cfg(feature = "wifi-connect")] use core::net::Ipv4Addr; @@ -19,16 +19,15 @@ use embassy_net::{Config, StackResources}; use embassy_rp::bind_interrupts; use embassy_rp::clocks::RoscRng; use embassy_rp::gpio::{Level, Output}; -use embassy_rp::peripherals::USB; use embassy_rp::peripherals::{DMA_CH0, PIO0}; use embassy_rp::pio::{InterruptHandler as PioInterruptHandler, Pio}; -use embassy_rp::usb::{Driver, InterruptHandler as UsbInterruptHandler}; -// use log::info; -// use pico_website::unwrap; -use rand_core::RngCore; +use embedded_alloc::LlffHeap as Heap; use static_cell::StaticCell; use {defmt_rtt as _, panic_probe as _}; +#[global_allocator] +static HEAP: Heap = Heap::empty(); + #[cfg(feature = "dhcp")] mod dhcp; @@ -39,7 +38,7 @@ mod apps; mod socket; bind_interrupts!(struct Irqs { - USBCTRL_IRQ => UsbInterruptHandler; + // USBCTRL_IRQ => UsbInterruptHandler; PIO0_IRQ_0 => PioInterruptHandler; }); @@ -62,8 +61,20 @@ async fn net_task(mut runner: embassy_net::Runner<'static, cyw43::NetDriver<'sta #[embassy_executor::main] async fn main(spawner: Spawner) { + // Init heap + // const HEAP_SIZE: usize = 4096; + // static mut HEAP_MEM: [MaybeUninit; HEAP_SIZE] = [MaybeUninit::uninit(); HEAP_SIZE]; + // let heap: Heap = Heap::empty(); + // unsafe { heap.init(&raw mut HEAP_MEM as usize, HEAP_SIZE) } + { + use core::mem::MaybeUninit; + const HEAP_SIZE: usize = 1024; + static mut HEAP_MEM: [MaybeUninit; HEAP_SIZE] = [MaybeUninit::uninit(); HEAP_SIZE]; + unsafe { HEAP.init(&raw mut HEAP_MEM as usize, HEAP_SIZE) } + } + let p = embassy_rp::init(Default::default()); - let driver = Driver::new(p.USB, Irqs); + // let driver = Driver::new(p.USB, Irqs); // spawner.spawn(logger_task(driver)).unwrap(); let mut rng = RoscRng; diff --git a/src/socket.rs b/src/socket.rs index 9a8c614..338a4c4 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -1,3 +1,4 @@ +use anyhow::{Result, anyhow}; use base64::prelude::*; use core::fmt::Write; @@ -51,7 +52,8 @@ pub async fn listen_task< loop { Timer::after_secs(0).await; let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer); - socket.set_timeout(Some(Duration::from_secs(10))); + // socket.set_timeout(Some(Duration::from_secs(5))); + // socket.set_keep_alive(Some(Duration::from_secs(1))); info!("Socket {}: Listening on TCP:{}...", app.socket_name(), port); if let Err(e) = socket.accept(port).await { @@ -70,34 +72,31 @@ pub async fn listen_task< let mut path = String::::new(); let mut request_type = HttpRequestType::Get; let mut is_ws = false; - loop { + let r: Result<&str> = 'connection: loop { Timer::after_secs(0).await; match socket .read_with(|msg| { let (headers, content) = match from_utf8(msg) { - Ok(b) => { - info!("{}", b); - match b.split_once("\r\n\r\n") { - Some(t) => t, - None => (b, ""), - } - } - Err(_) => { - warn!("Non utf8 http request"); - return (0, Err(())); + Ok(b) => match b.split_once("\r\n\r\n") { + Some(t) => t, + None => (b, ""), + }, + Err(e) => { + return (0, Err(e.into())); } }; buf.clear(); - if let Err(_) = buf.push_str(content) { - warn!("Received content is bigger than maximum content!"); - return (0, Err(())); + if buf.push_str(content).is_err() { + return ( + 0, + Err(anyhow!("Received content is bigger than maximum content!")), + ); } let mut hl = headers.lines(); match hl.next() { None => { - warn!("Empty request"); - return (0, Err(())); + return (0, Err(anyhow!("Empty request"))); } Some(l1) => { let mut l1 = l1.split(' '); @@ -105,24 +104,20 @@ pub async fn listen_task< Some("GET") => HttpRequestType::Get, Some("POST") => HttpRequestType::Post, Some(t) => { - warn!("Unknown request type : {}", t); - return (0, Err(())); + return (0, Err(anyhow!("Unknown request type : {}", t))); } None => { - warn!("No request type"); - return (0, Err(())); + return (0, Err(anyhow!("No request type"))); } }; path.clear(); if let Err(_) = path.push_str(match l1.next() { Some(path) => path, None => { - warn!("No path"); - return (0, Err(())); + return (0, Err(anyhow!("No path"))); } }) { - warn!("Path is too big!"); - return (0, Err(())); + return (0, Err(anyhow!("Path is too big!"))); } } }; @@ -143,8 +138,7 @@ pub async fn listen_task< } } let Some(host) = host else { - warn!("No host!"); - return (0, Err(())); + return (0, Err(anyhow!("No host!"))); }; info!( "Socket {}: {:?}{} request for {}{}", @@ -157,17 +151,14 @@ pub async fn listen_task< buf.clear(); if is_ws { let Some(key) = ws_key else { - warn!("No ws key!"); - return (0, Err(())); + return (0, Err(anyhow!("No ws key!"))); }; - if let Err(_) = buf.push_str(key) { - warn!("Ws key is too long!"); - return (0, Err(())); + if buf.push_str(key).is_err() { + return (0, Err(anyhow!("Ws key is too long!"))); } } else { - if let Err(_) = buf.push_str(content) { - warn!("Content is too long!"); - return (0, Err(())); + if buf.push_str(content).is_err() { + return (0, Err(anyhow!("Content is too long!"))); } } (msg.len(), Ok(())) @@ -175,10 +166,9 @@ pub async fn listen_task< .await { Ok(Ok(())) => {} - Ok(Err(())) => break, + Ok(Err(e)) => break 'connection Err(e), Err(e) => { - warn!("Error while receiving : {:?}", e); - break; + break 'connection Ok("connection reset"); } }; @@ -186,7 +176,11 @@ pub async fn listen_task< let res_content: Result, core::fmt::Error> = try { if is_ws { if !app.accept_ws(&path) { - warn!("No ws there!"); + warn!( + "Socket {}: client tried to access unknown ws path : {}", + app.socket_name(), + path + ); write!( &mut head_buf, "{}\r\n\r\n", @@ -216,7 +210,8 @@ pub async fn listen_task< &mut head_buf, "\r\n\ Content-Type: text/{}\r\n\ - Content-Length: {}\r\n", + Content-Length: {}\r\n\ + Connection: close\r\n", res_type, c.len() )?; @@ -229,8 +224,7 @@ pub async fn listen_task< let res_content = match res_content { Ok(rc) => rc, Err(e) => { - warn!("res buffer write error : {}", Debug2Format(&e)); - break; + break 'connection Err(anyhow!("Res buffer write error : {:?}", e)); } }; @@ -243,23 +237,37 @@ pub async fn listen_task< } }; - if let Err(e) = w { - warn!("write error: {:?}", e); - break; + if let Err(_) = w { + break 'connection Ok("connection reset"); }; if is_ws { - break; + break 'connection Ok(""); + } + }; + match r { + Ok("") => { + info!("Socket {}: Closing connection", app.socket_name()); + } + Ok(msg) => { + info!("Socket {}: Closing connection ({})", app.socket_name(), msg); + } + Err(e) => { + warn!( + "Socket {}: Closing connection ({})", + app.socket_name(), + Display2Format(&e) + ); } } if is_ws { let mut buf = buf.into_bytes(); - unwrap!(buf.resize_default(BUF_LEN)); + buf.resize_default(BUF_LEN).unwrap(); app.handle_ws::( &path, Ws::new( &mut socket, - &mut unwrap!(buf.into_array()), + &mut buf.into_array().unwrap(), app.socket_name(), ), ) diff --git a/src/socket/ws.rs b/src/socket/ws.rs index 8ab98d8..70fec6e 100644 --- a/src/socket/ws.rs +++ b/src/socket/ws.rs @@ -1,5 +1,6 @@ use core::str::from_utf8; +use anyhow::{Result, anyhow}; use embassy_net::tcp::{TcpReader, TcpSocket, TcpWriter}; use embassy_time::Instant; use embedded_io_async::ReadReady; @@ -49,31 +50,35 @@ struct WsTx<'a> { socket: TcpWriter<'a>, } impl<'a> WsTx<'a> { - pub async fn send_with Result>( + pub async fn send_with Result>( &mut self, msg_code: u8, f: F, - ) -> Result<(), ()> { + ) -> Result<()> { if self.send_with_no_flush(msg_code, &f).await.is_err() { - self.socket.flush().await.map_err(|_| ())?; + self.socket + .flush() + .await + .map_err(|_| anyhow!("connection reset"))?; self.send_with_no_flush(msg_code, f).await } else { Ok(()) } } - pub async fn send_with_no_flush Result>( + pub async fn send_with_no_flush Result>( &mut self, msg_code: u8, f: F, - ) -> Result<(), ()> { + ) -> Result<()> { self.socket .write_with(|buf| { if buf.len() < 6 { - return (0, Err(())); + return (0, Err(anyhow!("buffer too small"))); } buf[0] = 0b1000_0000 | msg_code; - let Ok(n) = f(&mut buf[4..]) else { - return (0, Err(())); + let n = match f(&mut buf[4..]) { + Ok(n) => n, + Err(e) => return (0, Err(e)), }; if n < 126 { buf[1] = n as u8; @@ -85,13 +90,13 @@ impl<'a> WsTx<'a> { } }) .await - .map_err(|_| ())? + .map_err(|_| anyhow!("connection reset"))? } - pub async fn send<'m>(&mut self, msg: WsMsg<'m>) -> Result<(), ()> { + pub async fn send<'m>(&mut self, msg: WsMsg<'m>) -> Result<()> { self.send_with(msg.code(), |buf| { let msg = msg.as_bytes(); if buf.len() < msg.len() { - Err(()) + Err(anyhow!("buffer smaller than message")) } else { buf[..msg.len()].copy_from_slice(msg); Ok(msg.len()) @@ -99,9 +104,9 @@ impl<'a> WsTx<'a> { }) .await } - pub async fn send_json(&mut self, msg: &T) -> Result<(), ()> { + pub async fn send_json(&mut self, msg: &T) -> Result<()> { self.send_with(WsMsg::TEXT, |buf| { - serde_json_core::to_slice(msg, buf).map_err(|_| ()) + serde_json_core::to_slice(msg, buf).map_err(|e| e.into()) }) .await } @@ -128,23 +133,21 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> { } } /// Do this often to respond to pings - pub async fn rcv(&mut self) -> Result, ()> { + pub async fn rcv(&mut self) -> Result> { let n = match self.rx.msg_in_buf.take() { Some(n) => { defmt::assert!(n.0 + n.1 <= self.rx.buf.len()); self.rx.buf.copy_within(n.0..n.0 + n.1, 0); if unwrap!(self.rx.socket.read_ready()) { - let n_rcv = match self.rx.socket.read(&mut self.rx.buf[n.1..]).await { - Ok(0) => { - warn!("read EOF"); - return Err(()); - } - Ok(n) => n, - Err(e) => { - warn!("Socket {}: read error: {:?}", self.name, e); - return Err(()); - } - }; + let n_rcv = self + .rx + .socket + .read(&mut self.rx.buf[n.1..]) + .await + .map_err(|_| anyhow!("connection reset"))?; + if n_rcv == 0 { + return Err(anyhow!("read EOF")); + } n.1 + n_rcv } else { n.1 @@ -152,16 +155,17 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> { } None => { if unwrap!(self.rx.socket.read_ready()) { - match self.rx.socket.read(self.rx.buf).await { - Ok(0) => { - warn!("read EOF"); - return Err(()); - } - Ok(n) => n, - Err(e) => { - warn!("Socket {}: read error: {:?}", self.name, e); - return Err(()); + match self + .rx + .socket + .read(self.rx.buf) + .await + .map_err(|_| anyhow!("connection reset"))? + { + 0 => { + return Err(anyhow!("read EOF")); } + n => n, } } else { return Ok(None); @@ -170,46 +174,37 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> { }; if self.rx.buf[0] & 0b1000_0000 == 0 { - warn!("Fragmented ws messages are not supported!"); - return Err(()); + return Err(anyhow!("fragmented ws message")); } if self.rx.buf[0] & 0b0111_0000 != 0 { - warn!( - "Reserved ws bits are set : {}", + return Err(anyhow!( + "reserved ws bits are set {}", (self.rx.buf[0] >> 4) & 0b0111 - ); - return Err(()); + )); } let (length, n_after_length) = match self.rx.buf[1] & 0b0111_1111 { 126 => ( u64::from_le_bytes([0, 0, 0, 0, 0, 0, self.rx.buf[2], self.rx.buf[3]]), 4, ), - 127 => ( - u64::from_le_bytes(self.rx.buf[2..10].try_into().unwrap()), - 10, - ), + 127 => (u64::from_le_bytes(self.rx.buf[2..10].try_into()?), 10), l => (l as u64, 2), }; if length > 512 { - warn!("ws payload bigger than 512!"); - return Err(()); + return Err(anyhow!("ws payload bigger than 512!")); } let content = if self.rx.buf[1] & 0b1000_0000 != 0 { // masked message if n_after_length + 4 + length as usize > n { - warn!("ws payload smaller than length"); - return Err(()); + return Err(anyhow!("ws payload smaller than length")); } - let mask_key: [u8; 4] = self.rx.buf[n_after_length..n_after_length + 4] - .try_into() - .unwrap(); + let mask_key: [u8; 4] = self.rx.buf[n_after_length..n_after_length + 4].try_into()?; for (i, x) in self.rx.buf[n_after_length + 4..n_after_length + 4 + length as usize] .iter_mut() .enumerate() { - *x ^= unwrap!(mask_key.get(i % 4)); + *x ^= mask_key[i % 4]; } if n_after_length + 4 + (length as usize) < n { self.rx.msg_in_buf = Some(( @@ -220,8 +215,7 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> { &self.rx.buf[n_after_length + 4..n_after_length + 4 + length as usize] } else { if n_after_length + length as usize > n { - warn!("ws payload smaller than length"); - return Err(()); + return Err(anyhow!("ws payload smaller than length")); } if n_after_length + (length as usize) < n { self.rx.msg_in_buf = Some(( @@ -235,7 +229,7 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> { match self.rx.buf[0] & 0b0000_1111 { // Text message 1 => { - let content = from_utf8(&content).map_err(|_| ())?; + let content = from_utf8(&content)?; Ok(Some(WsMsg::Text(content))) } // Bytes @@ -253,30 +247,30 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> { } } } - pub async fn send(&mut self, msg: WsMsg<'_>) -> Result<(), ()> { + pub async fn send(&mut self, msg: WsMsg<'_>) -> Result<()> { self.tx.send(msg).await?; self.last_msg = Instant::now(); Ok(()) } - pub async fn send_json(&mut self, msg: &T) -> Result<(), ()> { + pub async fn send_json(&mut self, msg: &T) -> Result<()> { self.tx.send_json(msg).await?; self.last_msg = Instant::now(); Ok(()) } - pub async fn send_with Result>( + pub async fn send_with Result>( &mut self, msg_code: u8, f: F, - ) -> Result<(), ()> { + ) -> Result<()> { self.tx.send_with(msg_code, f).await?; self.last_msg = Instant::now(); Ok(()) } - pub async fn send_with_no_flush Result>( + pub async fn send_with_no_flush Result>( &mut self, msg_code: u8, f: F, - ) -> Result<(), ()> { + ) -> Result<()> { self.tx.send_with_no_flush(msg_code, f).await?; self.last_msg = Instant::now(); Ok(())