264 lines
7.5 KiB
Rust
264 lines
7.5 KiB
Rust
use core::sync::atomic::Ordering;
|
|
|
|
use embassy_sync::{blocking_mutex::raw::ThreadModeRawMutex, mutex::Mutex};
|
|
use embassy_time::{Duration, Timer};
|
|
use heapless::String;
|
|
// use log::{info, warn};
|
|
// use pico_website::unimplemented;
|
|
use defmt::*;
|
|
use portable_atomic::AtomicUsize;
|
|
use serde::Serialize;
|
|
|
|
use crate::{apps::App, socket::ws::WsMsg};
|
|
|
|
/// Number of user slots; user ids are the slot indices `0..USERS_LEN`.
///
/// Must be <= `u8::MAX - 1`: a stored message keeps its author as `author + 1` in a single
/// byte (0 is reserved for padding). Enforced at compile time right below.
pub const USERS_LEN: u8 = 4;
const _: () = core::assert!(USERS_LEN <= u8::MAX - 1);

/// Maximum length, in bytes, of a chat message's content (the text after the `send ` prefix).
const MSG_MAX_SIZE: usize = 500;
|
|
/// One chat message as handed to clients; serde serializes it with the field names
/// `id`, `author` and `content`, in that order.
#[derive(Debug, Serialize)]
struct Msg<'a> {
    /// Sequential message id: the first message ever pushed has id 0.
    id: usize,
    /// Id of the user who wrote the message
    /// (presumably the slot index returned by `Usernames::get_id` — verify against callers).
    author: u8,
    /// Message text, borrowed directly from the `Msgs` ring buffer.
    content: &'a str,
}
|
|
|
|
const MSGS_SIZE: usize = 100000;
|
|
const _: () = core::assert!(MSGS_SIZE > MSG_MAX_SIZE);
|
|
/// Fixed-size ring buffer holding the chat history.
#[derive(Debug)]
struct Msgs {
    /// Raw message storage.
    /// * Memory layout with sizes in bytes: ...|content: len|len: 2|author+1: 1|...
    ///   (`len` is a little-endian `u16`).
    /// * `author=0` means there is no message, it's just padding and should be skipped.
    /// * No message is split across the end of the buffer: when a record does not fit, the
    ///   remaining tail is zero-filled and writing restarts at index 0, overwriting the oldest
    ///   messages.
    inner: [u8; MSGS_SIZE],
    /// Index of the next byte to write (one past the end of the newest record).
    head: usize,
    /// Id the next pushed message will get; equals the number of messages ever pushed.
    next_msg: usize,
}
|
|
impl Msgs {
|
|
const fn default() -> Self {
|
|
Self {
|
|
inner: [0; _],
|
|
head: 0,
|
|
next_msg: 0,
|
|
}
|
|
}
|
|
fn push(&mut self, author: u8, content: &str) {
|
|
if self.head + content.len() + 3 >= MSGS_SIZE {
|
|
self.inner[self.head..].fill(0);
|
|
self.head = 0
|
|
}
|
|
self.inner[self.head..self.head + content.len()].copy_from_slice(content.as_bytes());
|
|
self.head += content.len();
|
|
self.inner[self.head..self.head + 2].copy_from_slice(&(content.len() as u16).to_le_bytes());
|
|
self.inner[self.head + 2] = author + 1;
|
|
self.head += 3;
|
|
self.next_msg += 1;
|
|
}
|
|
/// Iter messages from present to past
|
|
fn iter(&self) -> MsgsIter {
|
|
if self.head == 0 {
|
|
MsgsIter {
|
|
msgs: self,
|
|
head: 0,
|
|
current_id: 0,
|
|
finished: true,
|
|
}
|
|
} else {
|
|
MsgsIter {
|
|
msgs: self,
|
|
head: self.head,
|
|
current_id: self.next_msg - 1,
|
|
finished: false,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Iterator over a `Msgs` buffer yielding messages from newest to oldest (built by `Msgs::iter`).
#[derive(Debug)]
struct MsgsIter<'a> {
    /// Buffer being decoded.
    msgs: &'a Msgs,
    /// Read cursor: exclusive end index of the next record to decode (records are read
    /// backwards, trailer first).
    head: usize,
    /// Set once no further message will be yielded.
    finished: bool,
    /// Id of the next message to yield; decremented after each yielded message.
    current_id: usize,
}
|
|
impl<'a> Iterator for MsgsIter<'a> {
|
|
type Item = Msg<'a>;
|
|
/// We trust msgs.inner validity in this function, it might panic or do UB if msgs.inner is not valid
|
|
fn next(&mut self) -> Option<Self::Item> {
|
|
if self.finished {
|
|
return None;
|
|
}
|
|
if self.head == 0 {
|
|
self.head = MSGS_SIZE;
|
|
}
|
|
let above = self.head > self.msgs.head;
|
|
if above && self.head < self.msgs.head + 3 {
|
|
self.finished = true;
|
|
return None;
|
|
}
|
|
let author = self.msgs.inner[self.head - 1];
|
|
self.head -= 1;
|
|
if author == 0 {
|
|
return self.next();
|
|
}
|
|
let author = author - 1;
|
|
let len = u16::from_le_bytes([
|
|
self.msgs.inner[self.head - 2],
|
|
self.msgs.inner[self.head - 1],
|
|
]) as usize;
|
|
self.head -= 2;
|
|
|
|
let content =
|
|
unsafe { str::from_utf8_unchecked(&self.msgs.inner[self.head - len..self.head]) };
|
|
self.head -= len;
|
|
if above && self.head < self.msgs.head {
|
|
self.finished = true;
|
|
return None;
|
|
}
|
|
let id = self.current_id;
|
|
if self.current_id == 0 {
|
|
self.finished = true;
|
|
} else {
|
|
self.current_id -= 1;
|
|
}
|
|
|
|
Some(Msg {
|
|
id,
|
|
author,
|
|
content,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Chat history shared by every chat connection (each `handle_ws` locks it to read and push).
static MSGS: Mutex<ThreadModeRawMutex, Msgs> = Mutex::new(Msgs::default());
/// Maximum username length: the capacity of the `heapless::String` each name is stored in.
const USERNAME_MAX_LEN: usize = 16;
|
|
/// Username of every user slot, indexed by user id; `None` marks a free slot.
/// Sent to clients with `send_json` whenever `USERNAMES_VERSION` changes.
#[derive(Serialize)]
pub struct Usernames([Option<String<USERNAME_MAX_LEN>>; USERS_LEN as usize]);
|
|
impl Usernames {
|
|
const fn default() -> Self {
|
|
Self([None, None, None, None])
|
|
}
|
|
pub fn get_id(&mut self, name: &str) -> Option<u8> {
|
|
for (i, un) in self.0.iter().enumerate() {
|
|
if let Some(n) = un {
|
|
if n.as_str() == name {
|
|
return Some(i as u8);
|
|
}
|
|
}
|
|
}
|
|
for (i, un) in self.0.iter_mut().enumerate() {
|
|
if *un == None {
|
|
*un = Some(String::new());
|
|
un.as_mut().unwrap().push_str(name).unwrap();
|
|
USERNAMES_VERSION.add(1, Ordering::Relaxed);
|
|
return Some(i as u8);
|
|
}
|
|
}
|
|
None
|
|
}
|
|
}
|
|
/// Username table shared by all connections; a user's id is its slot index.
pub static USERNAMES: Mutex<ThreadModeRawMutex, Usernames> = Mutex::new(Usernames::default());
/// Incremented by `Usernames::get_id` whenever it registers a new name; `handle_ws` compares it
/// with the last value it sent to decide when to resend `USERNAMES`.
pub static USERNAMES_VERSION: AtomicUsize = AtomicUsize::new(0);
|
|
|
|
/// Chat app state for one websocket client (counters are reset at the start of `handle_ws`).
pub struct ChatApp {
    /// User id under which this client's messages are stored.
    id: u8,
    /// Id of the next message to send to client (so that 0 means no message has been sent)
    next_msg: usize,
    /// Value of `USERNAMES_VERSION` when the username table was last sent to this client.
    usernames_version: usize,
}
|
|
impl ChatApp {
|
|
pub fn new(id: u8) -> Self {
|
|
Self {
|
|
id,
|
|
next_msg: 0,
|
|
usernames_version: 0,
|
|
}
|
|
}
|
|
}
|
|
impl App for ChatApp {
|
|
fn socket_name(&self) -> &'static str {
|
|
"chat"
|
|
}
|
|
fn accept_ws(&self, path: &str) -> bool {
|
|
path == "/"
|
|
}
|
|
async fn handle_ws<const BUF_SIZE: usize>(
|
|
&mut self,
|
|
_path: &str,
|
|
mut ws: crate::socket::ws::Ws<'_, BUF_SIZE>,
|
|
) {
|
|
self.usernames_version = 0;
|
|
self.next_msg = 0;
|
|
Timer::after_millis(500).await;
|
|
let r: Result<(), ()> = try {
|
|
loop {
|
|
Timer::after_millis(1).await;
|
|
{
|
|
let uv = USERNAMES_VERSION.load(Ordering::Relaxed);
|
|
if self.usernames_version < uv {
|
|
ws.send_json(&(*USERNAMES.lock().await)).await?;
|
|
}
|
|
self.usernames_version = uv;
|
|
}
|
|
{
|
|
let msgs = MSGS.lock().await;
|
|
for m in msgs.iter() {
|
|
if m.id >= self.next_msg {
|
|
ws.send_json(&m).await?;
|
|
}
|
|
}
|
|
self.next_msg = msgs.next_msg;
|
|
}
|
|
if ws.last_msg.elapsed() >= Duration::from_secs(5) {
|
|
ws.send(WsMsg::Ping(&[])).await?;
|
|
}
|
|
while let Some(r) = ws.rcv().await? {
|
|
info!("{:?}", r);
|
|
if let WsMsg::Text(r) = r {
|
|
if r.starts_with("send ") {
|
|
if r.len() > 5 + MSG_MAX_SIZE {
|
|
warn!("Message too long! (len={})", r.len() - 5);
|
|
return;
|
|
}
|
|
{
|
|
MSGS.lock()
|
|
.await
|
|
.push(self.id, r.get(5..).unwrap_or_default());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
if r.is_err() {
|
|
warn!(
|
|
"Socket {}: error in ws, terminating connection",
|
|
self.socket_name()
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn id_to_static_str(id: u8) -> &'static str {
|
|
match id {
|
|
0 => "0",
|
|
1 => "1",
|
|
2 => "2",
|
|
3 => "3",
|
|
4 => "4",
|
|
5 => "5",
|
|
6 => "6",
|
|
7 => "7",
|
|
8 => "8",
|
|
9 => "9",
|
|
10 => "10",
|
|
11 => "11",
|
|
12 => "12",
|
|
13 => "13",
|
|
14 => "14",
|
|
15 => "15",
|
|
_ => defmt::unimplemented!(),
|
|
}
|
|
}
|