pico-website/src/apps/chat.rs
2025-09-03 22:04:01 +02:00

264 lines
7.5 KiB
Rust

use core::sync::atomic::Ordering;
use embassy_sync::{blocking_mutex::raw::ThreadModeRawMutex, mutex::Mutex};
use embassy_time::{Duration, Timer};
use heapless::String;
// use log::{info, warn};
// use pico_website::unimplemented;
use defmt::*;
use portable_atomic::AtomicUsize;
use serde::Serialize;
use crate::{apps::App, socket::ws::WsMsg};
/// Number of user slots kept by the chat (length of the `Usernames` table).
///
/// Must be <= `u8::MAX - 1`: the message store writes `author + 1` into a single
/// byte and reserves the value 0 for padding.
/// NOTE(review): this assumes author ids are always below `USERS_LEN` — confirm
/// against the callers of `ChatApp::new`.
pub const USERS_LEN: u8 = 4;
/// Maximum accepted size, in bytes, of the content of one chat message.
const MSG_MAX_SIZE: usize = 500;
/// One chat message, borrowed from the message store and serialized as JSON
/// when it is sent to a client.
#[derive(Debug, Serialize)]
struct Msg<'a> {
/// Sequence number of the message (the first message ever pushed has id 0).
id: usize,
/// Id of the author, as it was given to `Msgs::push`.
author: u8,
/// Text of the message, borrowed from the store's buffer.
content: &'a str,
}
const MSGS_SIZE: usize = 100000;
const _: () = core::assert!(MSGS_SIZE > MSG_MAX_SIZE);
/// Fixed-size byte store holding the chat history.
///
/// Messages are appended at `head`. When a message does not fit before the end of
/// `inner`, the rest of the buffer is zero-filled and writing restarts at index 0,
/// overwriting the oldest messages.
#[derive(Debug)]
struct Msgs {
/// Raw storage.
/// * Memory layout, with sizes in bytes: `...|content: len|len: 2|author+1: 1|...`
///   (`len` is stored little-endian).
/// * A stored `author+1` byte equal to 0 means there is no message there: it is
///   padding and must be skipped.
/// * A message is never split across the end of the buffer.
inner: [u8; MSGS_SIZE],
/// Index of the next byte to write, i.e. one past the trailer of the newest message.
head: usize,
/// Id the next pushed message will get (incremented by every push).
next_msg: usize,
}
impl Msgs {
    /// Creates an empty store: the buffer is all padding and no id has been handed out.
    const fn default() -> Self {
        Self {
            inner: [0; MSGS_SIZE],
            head: 0,
            next_msg: 0,
        }
    }

    /// Appends a message written by `author` at the write head.
    ///
    /// When the message (content + 3-byte trailer) does not fit before the end of
    /// the buffer, the tail is zero-filled (padding) and the message is written at
    /// index 0 instead, overwriting the oldest messages.
    ///
    /// A message that the layout cannot represent is dropped without touching the
    /// store, because `MsgsIter` trusts the buffer in `unsafe` code:
    /// * `author == u8::MAX`: `author + 1` would overflow into the padding tag 0;
    /// * content longer than `u16::MAX` bytes: the 2-byte length would be truncated;
    /// * content that cannot fit in the buffer together with its trailer.
    fn push(&mut self, author: u8, content: &str) {
        let size = content.len();
        let Some(tag) = author.checked_add(1) else {
            return;
        };
        if size > u16::MAX as usize || size + 3 >= MSGS_SIZE {
            return;
        }

        // `>=` (not `>`) keeps `head` strictly below `MSGS_SIZE` after every push.
        if self.head + size + 3 >= MSGS_SIZE {
            self.inner[self.head..].fill(0);
            self.head = 0;
        }

        let trailer = self.head + size;
        self.inner[self.head..trailer].copy_from_slice(content.as_bytes());
        // Lossless: `size <= u16::MAX` was checked above.
        self.inner[trailer..trailer + 2].copy_from_slice(&(size as u16).to_le_bytes());
        self.inner[trailer + 2] = tag;
        self.head = trailer + 3;
        self.next_msg += 1;
    }

    /// Iterates over the stored messages from the newest to the oldest.
    ///
    /// `head == 0` only happens before the first push, in which case the iterator
    /// is created already finished.
    fn iter(&self) -> MsgsIter<'_> {
        let empty = self.head == 0;
        MsgsIter {
            msgs: self,
            head: self.head,
            current_id: if empty { 0 } else { self.next_msg - 1 },
            finished: empty,
        }
    }
}
/// Iterator over the messages of a `Msgs` store, from the newest to the oldest.
#[derive(Debug)]
struct MsgsIter<'a> {
/// Store being read.
msgs: &'a Msgs,
/// Read cursor: index one past the last byte of the next message to decode
/// (messages are decoded backwards, starting with their trailer).
head: usize,
/// Set once the iteration is over; `next` then always returns `None`.
finished: bool,
/// Id given to the next yielded message; decremented after each message.
current_id: usize,
}
impl<'a> Iterator for MsgsIter<'a> {
    type Item = Msg<'a>;

    /// Decodes the message that ends at the read cursor and moves the cursor to
    /// its first byte.
    ///
    /// The iteration goes from the write head down to index 0, then continues from
    /// the end of the buffer down to the write head (older messages left by the
    /// previous lap). It stops when the message with id 0 has been yielded, or when
    /// the cursor, coming from the end of the buffer, reaches the write head.
    ///
    /// We trust `msgs.inner` validity in this function: it might panic or do UB if
    /// `msgs.inner` is not valid.
    fn next(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }
        let store = self.msgs;

        // Skip padding bytes until an author tag is found. This is a loop rather
        // than a recursive call: up to several hundred padding bytes may follow
        // the last message before the end of the buffer.
        let (above, author) = loop {
            if self.head == 0 {
                self.head = MSGS_SIZE;
            }
            // `above`: the cursor is in the part of the buffer that the writer has
            // not reached yet during its current lap.
            let above = self.head > store.head;
            if above && self.head < store.head + 3 {
                // Fewer than 3 bytes left before the write head: the trailer that
                // would end here has been (partially) overwritten.
                self.finished = true;
                return None;
            }
            self.head -= 1;
            match store.inner[self.head] {
                0 => continue,
                tag => break (above, tag - 1),
            }
        };

        let len = u16::from_le_bytes([store.inner[self.head - 2], store.inner[self.head - 1]])
            as usize;
        self.head -= 2;
        let start = self.head - len;
        let bytes = &store.inner[start..self.head];
        self.head = start;

        if above && self.head <= store.head {
            // The cursor caught up with the write head: nothing older is left.
            // Without stopping on equality, the next call would see the cursor as
            // "not above" and walk the whole ring again, yielding duplicates.
            self.finished = true;
            if self.head < store.head {
                // The content of this message has been partially overwritten.
                return None;
            }
        }

        // SAFETY: `bytes` is the content of a message that has not been overwritten
        // (checked just above); `Msgs::push` copied it from a `&str` and never
        // splits a message, so it is valid UTF-8.
        let content = unsafe { str::from_utf8_unchecked(bytes) };

        let id = self.current_id;
        if id == 0 {
            self.finished = true;
        } else {
            self.current_id = id - 1;
        }
        Some(Msg {
            id,
            author,
            content,
        })
    }
}
/// Chat history, shared by the chat websocket handlers behind an async mutex.
static MSGS: Mutex<ThreadModeRawMutex, Msgs> = Mutex::new(Msgs::default());
/// Capacity of the `heapless::String` that stores one username.
const USERNAME_MAX_LEN: usize = 16;
/// Table of registered usernames; the index of an entry is the id of that user.
/// `None` marks a slot that is still free.
#[derive(Serialize)]
pub struct Usernames([Option<String<USERNAME_MAX_LEN>>; USERS_LEN as usize]);
impl Usernames {
const fn default() -> Self {
Self([None, None, None, None])
}
pub fn get_id(&mut self, name: &str) -> Option<u8> {
for (i, un) in self.0.iter().enumerate() {
if let Some(n) = un {
if n.as_str() == name {
return Some(i as u8);
}
}
}
for (i, un) in self.0.iter_mut().enumerate() {
if *un == None {
*un = Some(String::new());
un.as_mut().unwrap().push_str(name).unwrap();
USERNAMES_VERSION.add(1, Ordering::Relaxed);
return Some(i as u8);
}
}
None
}
}
/// Table of registered usernames, behind an async mutex.
pub static USERNAMES: Mutex<ThreadModeRawMutex, Usernames> = Mutex::new(Usernames::default());
/// Incremented each time a new username is registered. A connection compares it
/// with the version it last handled to know when to send the table again.
pub static USERNAMES_VERSION: AtomicUsize = AtomicUsize::new(0);
/// Per-connection state of the chat application.
pub struct ChatApp {
/// Id passed to `ChatApp::new`; used as the author of the messages received on
/// this connection.
id: u8,
/// Id of the next message to send to client (so that 0 means no message has been sent)
next_msg: usize,
/// Value of `USERNAMES_VERSION` last handled by this connection.
usernames_version: usize,
}
impl ChatApp {
    /// Builds the state of a chat connection whose messages are authored by `id`.
    ///
    /// Nothing has been sent to the client yet: both the message cursor and the
    /// known usernames version start at 0.
    pub fn new(id: u8) -> Self {
        let (next_msg, usernames_version) = (0, 0);
        Self {
            usernames_version,
            next_msg,
            id,
        }
    }
}
impl App for ChatApp {
    /// Name used to identify this application's socket in logs.
    fn socket_name(&self) -> &'static str {
        "chat"
    }

    /// Only the root path is served over websocket.
    fn accept_ws(&self, path: &str) -> bool {
        path == "/"
    }

    /// Serves one websocket connection until it fails.
    ///
    /// Each round of the loop:
    /// 1. sends the usernames table when `USERNAMES_VERSION` has grown;
    /// 2. sends the messages the client has not received yet, newest first;
    /// 3. sends a ping when `ws.last_msg` is at least 5 seconds old;
    /// 4. handles the received frames: a text frame `send <content>` appends
    ///    `<content>` to the history with this connection's id as author.
    ///
    /// A `send` whose content exceeds `MSG_MAX_SIZE` bytes makes the handler
    /// return immediately.
    async fn handle_ws<const BUF_SIZE: usize>(
        &mut self,
        _path: &str,
        mut ws: crate::socket::ws::Ws<'_, BUF_SIZE>,
    ) {
        // A new connection starts from scratch: everything must be (re)sent.
        self.usernames_version = 0;
        self.next_msg = 0;
        Timer::after_millis(500).await;
        let outcome: Result<(), ()> = try {
            loop {
                Timer::after_millis(1).await;
                {
                    let version = USERNAMES_VERSION.load(Ordering::Relaxed);
                    if self.usernames_version < version {
                        ws.send_json(&(*USERNAMES.lock().await)).await?;
                    }
                    self.usernames_version = version;
                }
                {
                    let msgs = MSGS.lock().await;
                    let first_unsent = self.next_msg;
                    // Ids strictly decrease along the iteration, so the unsent
                    // messages are exactly its leading run: stop at the first
                    // already-sent one instead of decoding the whole history on
                    // every round while holding the lock.
                    for msg in msgs.iter().take_while(|m| m.id >= first_unsent) {
                        ws.send_json(&msg).await?;
                    }
                    self.next_msg = msgs.next_msg;
                }
                if ws.last_msg.elapsed() >= Duration::from_secs(5) {
                    ws.send(WsMsg::Ping(&[])).await?;
                }
                while let Some(frame) = ws.rcv().await? {
                    info!("{:?}", frame);
                    if let WsMsg::Text(text) = frame {
                        if let Some(content) = text.strip_prefix("send ") {
                            if content.len() > MSG_MAX_SIZE {
                                warn!("Message too long! (len={})", content.len());
                                return;
                            }
                            MSGS.lock().await.push(self.id, content);
                        }
                    }
                }
            }
        };
        if outcome.is_err() {
            warn!(
                "Socket {}: error in ws, terminating connection",
                self.socket_name()
            );
        }
    }
}
/// Returns the decimal representation of `id` as a `'static` string.
///
/// Only ids from 0 to 15 are supported; any other value hits
/// `defmt::unimplemented!()`.
pub async fn id_to_static_str(id: u8) -> &'static str {
    // Index `i` holds the decimal text of `i`.
    const DECIMAL: [&str; 16] = [
        "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15",
    ];
    match DECIMAL.get(usize::from(id)).copied() {
        Some(text) => text,
        None => defmt::unimplemented!(),
    }
}