use core::sync::atomic::Ordering; use embassy_sync::{blocking_mutex::raw::ThreadModeRawMutex, mutex::Mutex}; use embassy_time::{Duration, Timer}; use heapless::String; // use log::{info, warn}; // use pico_website::unimplemented; use defmt::*; use portable_atomic::AtomicUsize; use serde::Serialize; use crate::{apps::App, socket::ws::WsMsg}; // Must be <= u8::MAX-1; pub const USERS_LEN: u8 = 4; const MSG_MAX_SIZE: usize = 500; #[derive(Debug, Serialize)] struct Msg<'a> { id: usize, author: u8, content: &'a str, } const MSGS_SIZE: usize = 100000; const _: () = core::assert!(MSGS_SIZE > MSG_MAX_SIZE); #[derive(Debug)] struct Msgs { /// * Memory layout with sizes in bytes : ...|content: len|len: 2|author+1: 1|... /// * `author=0` means theres no message, it's just padding and should be skipped. /// * No message is splitted inner: [u8; MSGS_SIZE], /// next byte index head: usize, next_msg: usize, } impl Msgs { const fn default() -> Self { Self { inner: [0; _], head: 0, next_msg: 0, } } fn push(&mut self, author: u8, content: &str) { if self.head + content.len() + 3 >= MSGS_SIZE { self.inner[self.head..].fill(0); self.head = 0 } self.inner[self.head..self.head + content.len()].copy_from_slice(content.as_bytes()); self.head += content.len(); self.inner[self.head..self.head + 2].copy_from_slice(&(content.len() as u16).to_le_bytes()); self.inner[self.head + 2] = author + 1; self.head += 3; self.next_msg += 1; } /// Iter messages from present to past fn iter(&self) -> MsgsIter { if self.head == 0 { MsgsIter { msgs: self, head: 0, current_id: 0, finished: true, } } else { MsgsIter { msgs: self, head: self.head, current_id: self.next_msg - 1, finished: false, } } } } #[derive(Debug)] struct MsgsIter<'a> { msgs: &'a Msgs, /// next byte index head: usize, finished: bool, current_id: usize, } impl<'a> Iterator for MsgsIter<'a> { type Item = Msg<'a>; /// We trust msgs.inner validity in this function, it might panic or do UB if msgs.inner is not valid fn next(&mut self) -> Option { if self.finished { return None; } if self.head == 0 { self.head = MSGS_SIZE; } let above = self.head > self.msgs.head; if above && self.head < self.msgs.head + 3 { self.finished = true; return None; } let author = self.msgs.inner[self.head - 1]; self.head -= 1; if author == 0 { return self.next(); } let author = author - 1; let len = u16::from_le_bytes([ self.msgs.inner[self.head - 2], self.msgs.inner[self.head - 1], ]) as usize; self.head -= 2; let content = unsafe { str::from_utf8_unchecked(&self.msgs.inner[self.head - len..self.head]) }; self.head -= len; if above && self.head < self.msgs.head { self.finished = true; return None; } let id = self.current_id; if self.current_id == 0 { self.finished = true; } else { self.current_id -= 1; } Some(Msg { id, author, content, }) } } static MSGS: Mutex = Mutex::new(Msgs::default()); const USERNAME_MAX_LEN: usize = 16; #[derive(Serialize)] pub struct Usernames([Option>; USERS_LEN as usize]); impl Usernames { const fn default() -> Self { Self([None, None, None, None]) } pub fn get_id(&mut self, name: &str) -> Option { for (i, un) in self.0.iter().enumerate() { if let Some(n) = un { if n.as_str() == name { return Some(i as u8); } } } for (i, un) in self.0.iter_mut().enumerate() { if *un == None { *un = Some(String::new()); un.as_mut().unwrap().push_str(name).unwrap(); USERNAMES_VERSION.add(1, Ordering::Relaxed); return Some(i as u8); } } None } } pub static USERNAMES: Mutex = Mutex::new(Usernames::default()); pub static USERNAMES_VERSION: AtomicUsize = AtomicUsize::new(0); pub struct ChatApp { id: u8, /// Id of the next message to send to client (so that 0 means no message has been sent) next_msg: usize, usernames_version: usize, } impl ChatApp { pub fn new(id: u8) -> Self { Self { id, next_msg: 0, usernames_version: 0, } } } impl App for ChatApp { fn socket_name(&self) -> &'static str { "chat" } fn accept_ws(&self, path: &str) -> bool { path == "/" } async fn handle_ws( &mut self, _path: &str, mut ws: crate::socket::ws::Ws<'_, BUF_SIZE>, ) { self.usernames_version = 0; self.next_msg = 0; Timer::after_millis(500).await; let r: Result<(), ()> = try { loop { Timer::after_millis(1).await; { let uv = USERNAMES_VERSION.load(Ordering::Relaxed); if self.usernames_version < uv { ws.send_json(&(*USERNAMES.lock().await)).await?; } self.usernames_version = uv; } { let msgs = MSGS.lock().await; for m in msgs.iter() { if m.id >= self.next_msg { ws.send_json(&m).await?; } } self.next_msg = msgs.next_msg; } if ws.last_msg.elapsed() >= Duration::from_secs(5) { ws.send(WsMsg::Ping(&[])).await?; } while let Some(r) = ws.rcv().await? { info!("{:?}", r); if let WsMsg::Text(r) = r { if r.starts_with("send ") { if r.len() > 5 + MSG_MAX_SIZE { warn!("Message too long! (len={})", r.len() - 5); return; } { MSGS.lock() .await .push(self.id, r.get(5..).unwrap_or_default()); } } } } } }; if r.is_err() { warn!( "Socket {}: error in ws, terminating connection", self.socket_name() ); } } } pub async fn id_to_static_str(id: u8) -> &'static str { match id { 0 => "0", 1 => "1", 2 => "2", 3 => "3", 4 => "4", 5 => "5", 6 => "6", 7 => "7", 8 => "8", 9 => "9", 10 => "10", 11 => "11", 12 => "12", 13 => "13", 14 => "14", 15 => "15", _ => defmt::unimplemented!(), } }