refactor errors to anyhow + improve chat

This commit is contained in:
Arkitu 2025-09-07 21:52:36 +02:00
parent 01fc28e1e7
commit 379ce3d010
9 changed files with 351 additions and 199 deletions

65
Cargo.lock generated
View File

@ -11,6 +11,12 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "anyhow"
version = "1.0.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100"
[[package]] [[package]]
name = "arrayvec" name = "arrayvec"
version = "0.7.6" version = "0.7.6"
@ -56,6 +62,12 @@ dependencies = [
"rustc_version", "rustc_version",
] ]
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.22.1" version = "0.22.1"
@ -146,6 +158,12 @@ dependencies = [
"unicode-width", "unicode-width",
] ]
[[package]]
name = "const-default"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b396d1f76d455557e1218ec8066ae14bba60b4b36ecd55577ba979f5db7ecaa"
[[package]] [[package]]
name = "cortex-m" name = "cortex-m"
version = "0.7.7" version = "0.7.7"
@ -605,6 +623,18 @@ dependencies = [
"embedded-io-async", "embedded-io-async",
] ]
[[package]]
name = "embedded-alloc"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f2de9133f68db0d4627ad69db767726c99ff8585272716708227008d3f1bddd"
dependencies = [
"const-default",
"critical-section",
"linked_list_allocator",
"rlsf",
]
[[package]] [[package]]
name = "embedded-hal" name = "embedded-hal"
version = "0.2.7" version = "0.2.7"
@ -934,6 +964,12 @@ version = "0.2.175"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
[[package]]
name = "linked_list_allocator"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9afa463f5405ee81cdb9cc2baf37e08ec7e4c8209442b5d72c04cfb2cd6e6286"
[[package]] [[package]]
name = "litrs" name = "litrs"
version = "0.4.2" version = "0.4.2"
@ -1113,7 +1149,8 @@ checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315"
name = "pico-website" name = "pico-website"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"base64", "anyhow",
"base64 0.22.1",
"cortex-m", "cortex-m",
"cortex-m-rt", "cortex-m-rt",
"cyw43", "cyw43",
@ -1127,6 +1164,7 @@ dependencies = [
"embassy-rp", "embassy-rp",
"embassy-sync 0.7.1", "embassy-sync 0.7.1",
"embassy-time", "embassy-time",
"embedded-alloc",
"embedded-io-async", "embedded-io-async",
"heapless", "heapless",
"log", "log",
@ -1355,6 +1393,18 @@ dependencies = [
"portable-atomic-util", "portable-atomic-util",
] ]
[[package]]
name = "rlsf"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "222fb240c3286247ecdee6fa5341e7cdad0ffdf8e7e401d9937f2d58482a20bf"
dependencies = [
"cfg-if",
"const-default",
"libc",
"svgbobdoc",
]
[[package]] [[package]]
name = "rp-pac" name = "rp-pac"
version = "7.0.0" version = "7.0.0"
@ -1560,6 +1610,19 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "svgbobdoc"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2c04b93fc15d79b39c63218f15e3fdffaa4c227830686e3b7c5f41244eb3e50"
dependencies = [
"base64 0.13.1",
"proc-macro2",
"quote",
"syn 1.0.109",
"unicode-width",
]
[[package]] [[package]]
name = "syn" name = "syn"
version = "1.0.109" version = "1.0.109"

View File

@ -69,3 +69,6 @@ ringbuf = { version = "*", default-features = false, features = [
percent-encoding = { version = "*", default-features = false } percent-encoding = { version = "*", default-features = false }
sha1 = { version = "*", default-features = false } sha1 = { version = "*", default-features = false }
base64 = { version = "*", default-features = false } base64 = { version = "*", default-features = false }
anyhow = { version = "*", default-features = false }
embedded-alloc = "*"

View File

@ -1,2 +1,3 @@
[toolchain] [toolchain]
channel = "nightly-2025-03-18" # channel = "nightly-2025-03-18"
channel = "nightly"

View File

@ -41,7 +41,8 @@ ws.onmessage = (event) => {
document.getElementById("send").onsubmit = (event) => { document.getElementById("send").onsubmit = (event) => {
event.preventDefault(); event.preventDefault();
// console.log(event, document.getElementById("sendcontent").value); // console.log(event, document.getElementById("sendcontent").value);
let timestamp = Date.now();
let content = document.getElementById("sendcontent"); let content = document.getElementById("sendcontent");
ws.send("send " + content.value); ws.send("send " + timestamp.toString() + " " + content.value);
content.value = ""; content.value = "";
}; };

View File

@ -1,10 +1,11 @@
use core::sync::atomic::Ordering; use core::{str::from_utf8_unchecked, sync::atomic::Ordering};
use embassy_sync::{blocking_mutex::raw::ThreadModeRawMutex, mutex::Mutex}; use embassy_sync::{blocking_mutex::raw::ThreadModeRawMutex, mutex::Mutex};
use embassy_time::{Duration, Timer}; use embassy_time::{Duration, Timer};
use heapless::String; use heapless::String;
// use log::{info, warn}; // use log::{info, warn};
// use pico_website::unimplemented; // use pico_website::unimplemented;
use anyhow::{Result, anyhow};
use defmt::*; use defmt::*;
use portable_atomic::AtomicUsize; use portable_atomic::AtomicUsize;
use serde::Serialize; use serde::Serialize;
@ -19,6 +20,7 @@ const MSG_MAX_SIZE: usize = 500;
struct Msg<'a> { struct Msg<'a> {
id: usize, id: usize,
author: u8, author: u8,
timestamp: u64,
content: &'a str, content: &'a str,
} }
@ -26,8 +28,8 @@ const MSGS_SIZE: usize = 100000;
const _: () = core::assert!(MSGS_SIZE > MSG_MAX_SIZE); const _: () = core::assert!(MSGS_SIZE > MSG_MAX_SIZE);
#[derive(Debug)] #[derive(Debug)]
struct Msgs { struct Msgs {
/// * Memory layout with sizes in bytes : ...|content: len|len: 2|author+1: 1|... /// * Memory layout with sizes in bytes : ...|content: len|len: 2|delimiter(255)| other_msg/pad(0)...
/// * `author=0` means theres no message, it's just padding and should be skipped. /// * content = author: 1 | timestamp: 8 | content: len-5
/// * No message is splitted /// * No message is splitted
inner: [u8; MSGS_SIZE], inner: [u8; MSGS_SIZE],
/// next byte index /// next byte index
@ -42,20 +44,24 @@ impl Msgs {
next_msg: 0, next_msg: 0,
} }
} }
fn push(&mut self, author: u8, content: &str) { fn push(&mut self, author: u8, timestamp: u64, content: &str) {
if self.head + content.len() + 3 >= MSGS_SIZE { let len = 1 + 8 + content.len();
if self.head + len + 3 >= MSGS_SIZE {
self.inner[self.head..].fill(0); self.inner[self.head..].fill(0);
self.head = 0 self.head = 0
} }
self.inner[self.head..self.head + content.len()].copy_from_slice(content.as_bytes()); self.inner[self.head] = author;
self.head += content.len(); self.inner[self.head + 1..self.head + 9].copy_from_slice(&timestamp.to_le_bytes());
self.inner[self.head..self.head + 2].copy_from_slice(&(content.len() as u16).to_le_bytes()); self.inner[self.head + 9..self.head + 9 + content.len()]
self.inner[self.head + 2] = author + 1; .copy_from_slice(content.as_bytes());
self.head += 3; self.inner[self.head + len..self.head + len + 2]
.copy_from_slice(&(len as u16).to_le_bytes());
self.inner[self.head + len + 2] = u8::MAX;
self.head += len + 3;
self.next_msg += 1; self.next_msg += 1;
} }
/// Iter messages from present to past /// Iter messages from present to past
fn iter(&self) -> MsgsIter { fn iter<'a>(&'a self) -> MsgsIter<'a> {
if self.head == 0 { if self.head == 0 {
MsgsIter { MsgsIter {
msgs: self, msgs: self,
@ -93,41 +99,86 @@ impl<'a> Iterator for MsgsIter<'a> {
self.head = MSGS_SIZE; self.head = MSGS_SIZE;
} }
let above = self.head > self.msgs.head; let above = self.head > self.msgs.head;
if above && self.head < self.msgs.head + 3 { if above && self.head <= self.msgs.head + 12 {
self.finished = true; self.finished = true;
return None; return None;
} }
let author = self.msgs.inner[self.head - 1]; match self.msgs.inner[self.head - 1] {
self.head -= 1; // Skip padding
if author == 0 { 0 => {
return self.next(); self.head -= 1;
self.next()
}
u8::MAX => {
let len = u16::from_le_bytes([
self.msgs.inner[self.head - 3],
self.msgs.inner[self.head - 2],
]) as usize;
self.head -= 3 + len;
let author = self.msgs.inner[self.head];
let timestamp = u64::from_le_bytes([
self.msgs.inner[self.head + 1],
self.msgs.inner[self.head + 2],
self.msgs.inner[self.head + 3],
self.msgs.inner[self.head + 4],
self.msgs.inner[self.head + 5],
self.msgs.inner[self.head + 6],
self.msgs.inner[self.head + 7],
self.msgs.inner[self.head + 8],
]);
let content = unsafe {
from_utf8_unchecked(&self.msgs.inner[self.head + 9..self.head + len])
};
let id = self.current_id;
if self.current_id == 0 {
self.finished = true;
} else {
self.current_id -= 1;
}
Some(Msg {
id,
author,
timestamp,
content,
})
}
_ => core::unreachable!(),
} }
let author = author - 1; // if self.msgs.inner[self.head - 1] == 0 {
let len = u16::from_le_bytes([ // return self.next();
self.msgs.inner[self.head - 2], // }
self.msgs.inner[self.head - 1],
]) as usize;
self.head -= 2;
let content = // let author = self.msgs.inner[self.head - 1];
unsafe { str::from_utf8_unchecked(&self.msgs.inner[self.head - len..self.head]) }; // self.head -= 1;
self.head -= len; // if author == 0 {
if above && self.head < self.msgs.head { // return self.next();
self.finished = true; // }
return None; // let author = author - 1;
} // let len = u16::from_le_bytes([
let id = self.current_id; // self.msgs.inner[self.head - 2],
if self.current_id == 0 { // self.msgs.inner[self.head - 1],
self.finished = true; // ]) as usize;
} else { // self.head -= 2;
self.current_id -= 1;
}
Some(Msg { // let content =
id, // unsafe { str::from_utf8_unchecked(&self.msgs.inner[self.head - len..self.head]) };
author, // self.head -= len;
content, // if above && self.head < self.msgs.head {
}) // self.finished = true;
// return None;
// }
// let id = self.current_id;
// if self.current_id == 0 {
// self.finished = true;
// } else {
// self.current_id -= 1;
// }
// Some(Msg {
// id,
// author,
// content,
// })
} }
} }
@ -191,50 +242,68 @@ impl App for ChatApp {
self.usernames_version = 0; self.usernames_version = 0;
self.next_msg = 0; self.next_msg = 0;
Timer::after_millis(500).await; Timer::after_millis(500).await;
let r: Result<(), ()> = try { let r: Result<()> = try {
loop { 'ws: {
Timer::after_millis(1).await; loop {
{ Timer::after_millis(1).await;
let uv = USERNAMES_VERSION.load(Ordering::Relaxed); {
if self.usernames_version < uv { let uv = USERNAMES_VERSION.load(Ordering::Relaxed);
ws.send_json(&(*USERNAMES.lock().await)).await?; if self.usernames_version < uv {
} ws.send_json(&(*USERNAMES.lock().await)).await?;
self.usernames_version = uv;
}
{
let msgs = MSGS.lock().await;
for m in msgs.iter() {
if m.id >= self.next_msg {
ws.send_json(&m).await?;
} }
self.usernames_version = uv;
} }
self.next_msg = msgs.next_msg; {
} let msgs = MSGS.lock().await;
if ws.last_msg.elapsed() >= Duration::from_secs(5) { for m in msgs.iter() {
ws.send(WsMsg::Ping(&[])).await?; if m.id >= self.next_msg {
} ws.send_json(&m).await?;
while let Some(r) = ws.rcv().await? {
info!("{:?}", r);
if let WsMsg::Text(r) = r {
if r.starts_with("send ") {
if r.len() > 5 + MSG_MAX_SIZE {
warn!("Message too long! (len={})", r.len() - 5);
return;
} }
{ }
MSGS.lock() self.next_msg = msgs.next_msg;
.await }
.push(self.id, r.get(5..).unwrap_or_default()); if ws.last_msg.elapsed() >= Duration::from_secs(5) {
ws.send(WsMsg::Ping(&[])).await?;
}
while let Some(r) = ws.rcv().await? {
info!("{:?}", r);
if let WsMsg::Text(r) = r {
match r.split_once(' ') {
Some(("send", m)) => {
let (t, msg) = m
.split_once(' ')
.ok_or(anyhow!("message cannot be splitted with ' '"))?;
info!("{}", t);
let timestamp = t.parse::<u64>()?;
if msg.len() > MSG_MAX_SIZE {
warn!("Message too long! (len={})", r.len() - 5);
break 'ws;
}
MSGS.lock().await.push(self.id, timestamp, msg);
}
_ => {}
} }
// if r.starts_with("send ") {
// if r.len() > 5 + MSG_MAX_SIZE {
// warn!("Message too long! (len={})", r.len() - 5);
// break 'ws;
// }
// {
// MSGS.lock()
// .await
// .push(self.id, r.get(5..).unwrap_or_default());
// }
// }
} }
} }
} }
} }
}; };
if r.is_err() { if let Err(e) = r {
warn!( warn!(
"Socket {}: error in ws, terminating connection", "Socket {}: error in ws\n{}",
self.socket_name() self.socket_name(),
Display2Format(&e)
); );
} }
} }

View File

@ -1,3 +1,4 @@
use anyhow::Result;
use core::ops::Not; use core::ops::Not;
use core::str::from_utf8_unchecked; use core::str::from_utf8_unchecked;
use embassy_sync::blocking_mutex::raw::ThreadModeRawMutex; use embassy_sync::blocking_mutex::raw::ThreadModeRawMutex;
@ -92,7 +93,7 @@ impl App for TttApp {
} }
async fn handle_ws<const BUF_SIZE: usize>(&mut self, _path: &str, mut ws: Ws<'_, BUF_SIZE>) { async fn handle_ws<const BUF_SIZE: usize>(&mut self, _path: &str, mut ws: Ws<'_, BUF_SIZE>) {
Timer::after_millis(500).await; Timer::after_millis(500).await;
let r: Result<(), ()> = try { let r: Result<()> = try {
loop { loop {
Timer::after_millis(1).await; Timer::after_millis(1).await;
let Ok(mut game) = GAME.try_lock() else { let Ok(mut game) = GAME.try_lock() else {
@ -140,10 +141,11 @@ impl App for TttApp {
} }
} }
}; };
if r.is_err() { if let Err(e) = r {
warn!( warn!(
"Socket {}: error in ws, terminating connection", "Socket {}: error in ws\n{}",
self.socket_name() self.socket_name(),
Display2Format(&e)
); );
} }
} }

View File

@ -6,9 +6,9 @@
#![feature(try_blocks)] #![feature(try_blocks)]
#![feature(impl_trait_in_bindings)] #![feature(impl_trait_in_bindings)]
#![feature(array_repeat)] #![feature(array_repeat)]
#![feature(generic_arg_infer)]
#![feature(async_iterator)] #![feature(async_iterator)]
use core::mem::MaybeUninit;
#[cfg(feature = "wifi-connect")] #[cfg(feature = "wifi-connect")]
use core::net::Ipv4Addr; use core::net::Ipv4Addr;
@ -19,16 +19,15 @@ use embassy_net::{Config, StackResources};
use embassy_rp::bind_interrupts; use embassy_rp::bind_interrupts;
use embassy_rp::clocks::RoscRng; use embassy_rp::clocks::RoscRng;
use embassy_rp::gpio::{Level, Output}; use embassy_rp::gpio::{Level, Output};
use embassy_rp::peripherals::USB;
use embassy_rp::peripherals::{DMA_CH0, PIO0}; use embassy_rp::peripherals::{DMA_CH0, PIO0};
use embassy_rp::pio::{InterruptHandler as PioInterruptHandler, Pio}; use embassy_rp::pio::{InterruptHandler as PioInterruptHandler, Pio};
use embassy_rp::usb::{Driver, InterruptHandler as UsbInterruptHandler}; use embedded_alloc::LlffHeap as Heap;
// use log::info;
// use pico_website::unwrap;
use rand_core::RngCore;
use static_cell::StaticCell; use static_cell::StaticCell;
use {defmt_rtt as _, panic_probe as _}; use {defmt_rtt as _, panic_probe as _};
#[global_allocator]
static HEAP: Heap = Heap::empty();
#[cfg(feature = "dhcp")] #[cfg(feature = "dhcp")]
mod dhcp; mod dhcp;
@ -39,7 +38,7 @@ mod apps;
mod socket; mod socket;
bind_interrupts!(struct Irqs { bind_interrupts!(struct Irqs {
USBCTRL_IRQ => UsbInterruptHandler<USB>; // USBCTRL_IRQ => UsbInterruptHandler<USB>;
PIO0_IRQ_0 => PioInterruptHandler<PIO0>; PIO0_IRQ_0 => PioInterruptHandler<PIO0>;
}); });
@ -62,8 +61,20 @@ async fn net_task(mut runner: embassy_net::Runner<'static, cyw43::NetDriver<'sta
#[embassy_executor::main] #[embassy_executor::main]
async fn main(spawner: Spawner) { async fn main(spawner: Spawner) {
// Init heap
// const HEAP_SIZE: usize = 4096;
// static mut HEAP_MEM: [MaybeUninit<u8>; HEAP_SIZE] = [MaybeUninit::uninit(); HEAP_SIZE];
// let heap: Heap = Heap::empty();
// unsafe { heap.init(&raw mut HEAP_MEM as usize, HEAP_SIZE) }
{
use core::mem::MaybeUninit;
const HEAP_SIZE: usize = 1024;
static mut HEAP_MEM: [MaybeUninit<u8>; HEAP_SIZE] = [MaybeUninit::uninit(); HEAP_SIZE];
unsafe { HEAP.init(&raw mut HEAP_MEM as usize, HEAP_SIZE) }
}
let p = embassy_rp::init(Default::default()); let p = embassy_rp::init(Default::default());
let driver = Driver::new(p.USB, Irqs); // let driver = Driver::new(p.USB, Irqs);
// spawner.spawn(logger_task(driver)).unwrap(); // spawner.spawn(logger_task(driver)).unwrap();
let mut rng = RoscRng; let mut rng = RoscRng;

View File

@ -1,3 +1,4 @@
use anyhow::{Result, anyhow};
use base64::prelude::*; use base64::prelude::*;
use core::fmt::Write; use core::fmt::Write;
@ -51,7 +52,8 @@ pub async fn listen_task<
loop { loop {
Timer::after_secs(0).await; Timer::after_secs(0).await;
let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer); let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer);
socket.set_timeout(Some(Duration::from_secs(10))); // socket.set_timeout(Some(Duration::from_secs(5)));
// socket.set_keep_alive(Some(Duration::from_secs(1)));
info!("Socket {}: Listening on TCP:{}...", app.socket_name(), port); info!("Socket {}: Listening on TCP:{}...", app.socket_name(), port);
if let Err(e) = socket.accept(port).await { if let Err(e) = socket.accept(port).await {
@ -70,34 +72,31 @@ pub async fn listen_task<
let mut path = String::<PATH_LEN>::new(); let mut path = String::<PATH_LEN>::new();
let mut request_type = HttpRequestType::Get; let mut request_type = HttpRequestType::Get;
let mut is_ws = false; let mut is_ws = false;
loop { let r: Result<&str> = 'connection: loop {
Timer::after_secs(0).await; Timer::after_secs(0).await;
match socket match socket
.read_with(|msg| { .read_with(|msg| {
let (headers, content) = match from_utf8(msg) { let (headers, content) = match from_utf8(msg) {
Ok(b) => { Ok(b) => match b.split_once("\r\n\r\n") {
info!("{}", b); Some(t) => t,
match b.split_once("\r\n\r\n") { None => (b, ""),
Some(t) => t, },
None => (b, ""), Err(e) => {
} return (0, Err(e.into()));
}
Err(_) => {
warn!("Non utf8 http request");
return (0, Err(()));
} }
}; };
buf.clear(); buf.clear();
if let Err(_) = buf.push_str(content) { if buf.push_str(content).is_err() {
warn!("Received content is bigger than maximum content!"); return (
return (0, Err(())); 0,
Err(anyhow!("Received content is bigger than maximum content!")),
);
} }
let mut hl = headers.lines(); let mut hl = headers.lines();
match hl.next() { match hl.next() {
None => { None => {
warn!("Empty request"); return (0, Err(anyhow!("Empty request")));
return (0, Err(()));
} }
Some(l1) => { Some(l1) => {
let mut l1 = l1.split(' '); let mut l1 = l1.split(' ');
@ -105,24 +104,20 @@ pub async fn listen_task<
Some("GET") => HttpRequestType::Get, Some("GET") => HttpRequestType::Get,
Some("POST") => HttpRequestType::Post, Some("POST") => HttpRequestType::Post,
Some(t) => { Some(t) => {
warn!("Unknown request type : {}", t); return (0, Err(anyhow!("Unknown request type : {}", t)));
return (0, Err(()));
} }
None => { None => {
warn!("No request type"); return (0, Err(anyhow!("No request type")));
return (0, Err(()));
} }
}; };
path.clear(); path.clear();
if let Err(_) = path.push_str(match l1.next() { if let Err(_) = path.push_str(match l1.next() {
Some(path) => path, Some(path) => path,
None => { None => {
warn!("No path"); return (0, Err(anyhow!("No path")));
return (0, Err(()));
} }
}) { }) {
warn!("Path is too big!"); return (0, Err(anyhow!("Path is too big!")));
return (0, Err(()));
} }
} }
}; };
@ -143,8 +138,7 @@ pub async fn listen_task<
} }
} }
let Some(host) = host else { let Some(host) = host else {
warn!("No host!"); return (0, Err(anyhow!("No host!")));
return (0, Err(()));
}; };
info!( info!(
"Socket {}: {:?}{} request for {}{}", "Socket {}: {:?}{} request for {}{}",
@ -157,17 +151,14 @@ pub async fn listen_task<
buf.clear(); buf.clear();
if is_ws { if is_ws {
let Some(key) = ws_key else { let Some(key) = ws_key else {
warn!("No ws key!"); return (0, Err(anyhow!("No ws key!")));
return (0, Err(()));
}; };
if let Err(_) = buf.push_str(key) { if buf.push_str(key).is_err() {
warn!("Ws key is too long!"); return (0, Err(anyhow!("Ws key is too long!")));
return (0, Err(()));
} }
} else { } else {
if let Err(_) = buf.push_str(content) { if buf.push_str(content).is_err() {
warn!("Content is too long!"); return (0, Err(anyhow!("Content is too long!")));
return (0, Err(()));
} }
} }
(msg.len(), Ok(())) (msg.len(), Ok(()))
@ -175,10 +166,9 @@ pub async fn listen_task<
.await .await
{ {
Ok(Ok(())) => {} Ok(Ok(())) => {}
Ok(Err(())) => break, Ok(Err(e)) => break 'connection Err(e),
Err(e) => { Err(e) => {
warn!("Error while receiving : {:?}", e); break 'connection Ok("connection reset");
break;
} }
}; };
@ -186,7 +176,11 @@ pub async fn listen_task<
let res_content: Result<Option<Content>, core::fmt::Error> = try { let res_content: Result<Option<Content>, core::fmt::Error> = try {
if is_ws { if is_ws {
if !app.accept_ws(&path) { if !app.accept_ws(&path) {
warn!("No ws there!"); warn!(
"Socket {}: client tried to access unknown ws path : {}",
app.socket_name(),
path
);
write!( write!(
&mut head_buf, &mut head_buf,
"{}\r\n\r\n", "{}\r\n\r\n",
@ -216,7 +210,8 @@ pub async fn listen_task<
&mut head_buf, &mut head_buf,
"\r\n\ "\r\n\
Content-Type: text/{}\r\n\ Content-Type: text/{}\r\n\
Content-Length: {}\r\n", Content-Length: {}\r\n\
Connection: close\r\n",
res_type, res_type,
c.len() c.len()
)?; )?;
@ -229,8 +224,7 @@ pub async fn listen_task<
let res_content = match res_content { let res_content = match res_content {
Ok(rc) => rc, Ok(rc) => rc,
Err(e) => { Err(e) => {
warn!("res buffer write error : {}", Debug2Format(&e)); break 'connection Err(anyhow!("Res buffer write error : {:?}", e));
break;
} }
}; };
@ -243,23 +237,37 @@ pub async fn listen_task<
} }
}; };
if let Err(e) = w { if let Err(_) = w {
warn!("write error: {:?}", e); break 'connection Ok("connection reset");
break;
}; };
if is_ws { if is_ws {
break; break 'connection Ok("");
}
};
match r {
Ok("") => {
info!("Socket {}: Closing connection", app.socket_name());
}
Ok(msg) => {
info!("Socket {}: Closing connection ({})", app.socket_name(), msg);
}
Err(e) => {
warn!(
"Socket {}: Closing connection ({})",
app.socket_name(),
Display2Format(&e)
);
} }
} }
if is_ws { if is_ws {
let mut buf = buf.into_bytes(); let mut buf = buf.into_bytes();
unwrap!(buf.resize_default(BUF_LEN)); buf.resize_default(BUF_LEN).unwrap();
app.handle_ws::<BUF_LEN>( app.handle_ws::<BUF_LEN>(
&path, &path,
Ws::new( Ws::new(
&mut socket, &mut socket,
&mut unwrap!(buf.into_array()), &mut buf.into_array().unwrap(),
app.socket_name(), app.socket_name(),
), ),
) )

View File

@ -1,5 +1,6 @@
use core::str::from_utf8; use core::str::from_utf8;
use anyhow::{Result, anyhow};
use embassy_net::tcp::{TcpReader, TcpSocket, TcpWriter}; use embassy_net::tcp::{TcpReader, TcpSocket, TcpWriter};
use embassy_time::Instant; use embassy_time::Instant;
use embedded_io_async::ReadReady; use embedded_io_async::ReadReady;
@ -49,31 +50,35 @@ struct WsTx<'a> {
socket: TcpWriter<'a>, socket: TcpWriter<'a>,
} }
impl<'a> WsTx<'a> { impl<'a> WsTx<'a> {
pub async fn send_with<F: Fn(&mut [u8]) -> Result<usize, ()>>( pub async fn send_with<F: Fn(&mut [u8]) -> Result<usize>>(
&mut self, &mut self,
msg_code: u8, msg_code: u8,
f: F, f: F,
) -> Result<(), ()> { ) -> Result<()> {
if self.send_with_no_flush(msg_code, &f).await.is_err() { if self.send_with_no_flush(msg_code, &f).await.is_err() {
self.socket.flush().await.map_err(|_| ())?; self.socket
.flush()
.await
.map_err(|_| anyhow!("connection reset"))?;
self.send_with_no_flush(msg_code, f).await self.send_with_no_flush(msg_code, f).await
} else { } else {
Ok(()) Ok(())
} }
} }
pub async fn send_with_no_flush<F: FnOnce(&mut [u8]) -> Result<usize, ()>>( pub async fn send_with_no_flush<F: FnOnce(&mut [u8]) -> Result<usize>>(
&mut self, &mut self,
msg_code: u8, msg_code: u8,
f: F, f: F,
) -> Result<(), ()> { ) -> Result<()> {
self.socket self.socket
.write_with(|buf| { .write_with(|buf| {
if buf.len() < 6 { if buf.len() < 6 {
return (0, Err(())); return (0, Err(anyhow!("buffer too small")));
} }
buf[0] = 0b1000_0000 | msg_code; buf[0] = 0b1000_0000 | msg_code;
let Ok(n) = f(&mut buf[4..]) else { let n = match f(&mut buf[4..]) {
return (0, Err(())); Ok(n) => n,
Err(e) => return (0, Err(e)),
}; };
if n < 126 { if n < 126 {
buf[1] = n as u8; buf[1] = n as u8;
@ -85,13 +90,13 @@ impl<'a> WsTx<'a> {
} }
}) })
.await .await
.map_err(|_| ())? .map_err(|_| anyhow!("connection reset"))?
} }
pub async fn send<'m>(&mut self, msg: WsMsg<'m>) -> Result<(), ()> { pub async fn send<'m>(&mut self, msg: WsMsg<'m>) -> Result<()> {
self.send_with(msg.code(), |buf| { self.send_with(msg.code(), |buf| {
let msg = msg.as_bytes(); let msg = msg.as_bytes();
if buf.len() < msg.len() { if buf.len() < msg.len() {
Err(()) Err(anyhow!("buffer smaller than message"))
} else { } else {
buf[..msg.len()].copy_from_slice(msg); buf[..msg.len()].copy_from_slice(msg);
Ok(msg.len()) Ok(msg.len())
@ -99,9 +104,9 @@ impl<'a> WsTx<'a> {
}) })
.await .await
} }
pub async fn send_json<T: serde::Serialize>(&mut self, msg: &T) -> Result<(), ()> { pub async fn send_json<T: serde::Serialize>(&mut self, msg: &T) -> Result<()> {
self.send_with(WsMsg::TEXT, |buf| { self.send_with(WsMsg::TEXT, |buf| {
serde_json_core::to_slice(msg, buf).map_err(|_| ()) serde_json_core::to_slice(msg, buf).map_err(|e| e.into())
}) })
.await .await
} }
@ -128,23 +133,21 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> {
} }
} }
/// Do this often to respond to pings /// Do this often to respond to pings
pub async fn rcv(&mut self) -> Result<Option<WsMsg>, ()> { pub async fn rcv(&mut self) -> Result<Option<WsMsg>> {
let n = match self.rx.msg_in_buf.take() { let n = match self.rx.msg_in_buf.take() {
Some(n) => { Some(n) => {
defmt::assert!(n.0 + n.1 <= self.rx.buf.len()); defmt::assert!(n.0 + n.1 <= self.rx.buf.len());
self.rx.buf.copy_within(n.0..n.0 + n.1, 0); self.rx.buf.copy_within(n.0..n.0 + n.1, 0);
if unwrap!(self.rx.socket.read_ready()) { if unwrap!(self.rx.socket.read_ready()) {
let n_rcv = match self.rx.socket.read(&mut self.rx.buf[n.1..]).await { let n_rcv = self
Ok(0) => { .rx
warn!("read EOF"); .socket
return Err(()); .read(&mut self.rx.buf[n.1..])
} .await
Ok(n) => n, .map_err(|_| anyhow!("connection reset"))?;
Err(e) => { if n_rcv == 0 {
warn!("Socket {}: read error: {:?}", self.name, e); return Err(anyhow!("read EOF"));
return Err(()); }
}
};
n.1 + n_rcv n.1 + n_rcv
} else { } else {
n.1 n.1
@ -152,16 +155,17 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> {
} }
None => { None => {
if unwrap!(self.rx.socket.read_ready()) { if unwrap!(self.rx.socket.read_ready()) {
match self.rx.socket.read(self.rx.buf).await { match self
Ok(0) => { .rx
warn!("read EOF"); .socket
return Err(()); .read(self.rx.buf)
} .await
Ok(n) => n, .map_err(|_| anyhow!("connection reset"))?
Err(e) => { {
warn!("Socket {}: read error: {:?}", self.name, e); 0 => {
return Err(()); return Err(anyhow!("read EOF"));
} }
n => n,
} }
} else { } else {
return Ok(None); return Ok(None);
@ -170,46 +174,37 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> {
}; };
if self.rx.buf[0] & 0b1000_0000 == 0 { if self.rx.buf[0] & 0b1000_0000 == 0 {
warn!("Fragmented ws messages are not supported!"); return Err(anyhow!("fragmented ws message"));
return Err(());
} }
if self.rx.buf[0] & 0b0111_0000 != 0 { if self.rx.buf[0] & 0b0111_0000 != 0 {
warn!( return Err(anyhow!(
"Reserved ws bits are set : {}", "reserved ws bits are set {}",
(self.rx.buf[0] >> 4) & 0b0111 (self.rx.buf[0] >> 4) & 0b0111
); ));
return Err(());
} }
let (length, n_after_length) = match self.rx.buf[1] & 0b0111_1111 { let (length, n_after_length) = match self.rx.buf[1] & 0b0111_1111 {
126 => ( 126 => (
u64::from_le_bytes([0, 0, 0, 0, 0, 0, self.rx.buf[2], self.rx.buf[3]]), u64::from_le_bytes([0, 0, 0, 0, 0, 0, self.rx.buf[2], self.rx.buf[3]]),
4, 4,
), ),
127 => ( 127 => (u64::from_le_bytes(self.rx.buf[2..10].try_into()?), 10),
u64::from_le_bytes(self.rx.buf[2..10].try_into().unwrap()),
10,
),
l => (l as u64, 2), l => (l as u64, 2),
}; };
if length > 512 { if length > 512 {
warn!("ws payload bigger than 512!"); return Err(anyhow!("ws payload bigger than 512!"));
return Err(());
} }
let content = if self.rx.buf[1] & 0b1000_0000 != 0 { let content = if self.rx.buf[1] & 0b1000_0000 != 0 {
// masked message // masked message
if n_after_length + 4 + length as usize > n { if n_after_length + 4 + length as usize > n {
warn!("ws payload smaller than length"); return Err(anyhow!("ws payload smaller than length"));
return Err(());
} }
let mask_key: [u8; 4] = self.rx.buf[n_after_length..n_after_length + 4] let mask_key: [u8; 4] = self.rx.buf[n_after_length..n_after_length + 4].try_into()?;
.try_into()
.unwrap();
for (i, x) in self.rx.buf[n_after_length + 4..n_after_length + 4 + length as usize] for (i, x) in self.rx.buf[n_after_length + 4..n_after_length + 4 + length as usize]
.iter_mut() .iter_mut()
.enumerate() .enumerate()
{ {
*x ^= unwrap!(mask_key.get(i % 4)); *x ^= mask_key[i % 4];
} }
if n_after_length + 4 + (length as usize) < n { if n_after_length + 4 + (length as usize) < n {
self.rx.msg_in_buf = Some(( self.rx.msg_in_buf = Some((
@ -220,8 +215,7 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> {
&self.rx.buf[n_after_length + 4..n_after_length + 4 + length as usize] &self.rx.buf[n_after_length + 4..n_after_length + 4 + length as usize]
} else { } else {
if n_after_length + length as usize > n { if n_after_length + length as usize > n {
warn!("ws payload smaller than length"); return Err(anyhow!("ws payload smaller than length"));
return Err(());
} }
if n_after_length + (length as usize) < n { if n_after_length + (length as usize) < n {
self.rx.msg_in_buf = Some(( self.rx.msg_in_buf = Some((
@ -235,7 +229,7 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> {
match self.rx.buf[0] & 0b0000_1111 { match self.rx.buf[0] & 0b0000_1111 {
// Text message // Text message
1 => { 1 => {
let content = from_utf8(&content).map_err(|_| ())?; let content = from_utf8(&content)?;
Ok(Some(WsMsg::Text(content))) Ok(Some(WsMsg::Text(content)))
} }
// Bytes // Bytes
@ -253,30 +247,30 @@ impl<'a, const BUF_SIZE: usize> Ws<'a, BUF_SIZE> {
} }
} }
} }
pub async fn send(&mut self, msg: WsMsg<'_>) -> Result<(), ()> { pub async fn send(&mut self, msg: WsMsg<'_>) -> Result<()> {
self.tx.send(msg).await?; self.tx.send(msg).await?;
self.last_msg = Instant::now(); self.last_msg = Instant::now();
Ok(()) Ok(())
} }
pub async fn send_json<T: serde::Serialize>(&mut self, msg: &T) -> Result<(), ()> { pub async fn send_json<T: serde::Serialize>(&mut self, msg: &T) -> Result<()> {
self.tx.send_json(msg).await?; self.tx.send_json(msg).await?;
self.last_msg = Instant::now(); self.last_msg = Instant::now();
Ok(()) Ok(())
} }
pub async fn send_with<F: Fn(&mut [u8]) -> Result<usize, ()>>( pub async fn send_with<F: Fn(&mut [u8]) -> Result<usize>>(
&mut self, &mut self,
msg_code: u8, msg_code: u8,
f: F, f: F,
) -> Result<(), ()> { ) -> Result<()> {
self.tx.send_with(msg_code, f).await?; self.tx.send_with(msg_code, f).await?;
self.last_msg = Instant::now(); self.last_msg = Instant::now();
Ok(()) Ok(())
} }
pub async fn send_with_no_flush<F: FnOnce(&mut [u8]) -> Result<usize, ()>>( pub async fn send_with_no_flush<F: FnOnce(&mut [u8]) -> Result<usize>>(
&mut self, &mut self,
msg_code: u8, msg_code: u8,
f: F, f: F,
) -> Result<(), ()> { ) -> Result<()> {
self.tx.send_with_no_flush(msg_code, f).await?; self.tx.send_with_no_flush(msg_code, f).await?;
self.last_msg = Instant::now(); self.last_msg = Instant::now();
Ok(()) Ok(())