diff --git a/src/apps/chat.rs b/src/apps/chat.rs index c39d9f8..b825ffe 100644 --- a/src/apps/chat.rs +++ b/src/apps/chat.rs @@ -181,10 +181,10 @@ impl App for ChatApp { fn accept_ws(&self, path: &str) -> bool { path == "/" } - async fn handle_ws<'a, const BUF_SIZE: usize>( - &'a mut self, + async fn handle_ws( + &mut self, _path: &str, - mut ws: crate::socket::ws::Ws<'a, BUF_SIZE>, + mut ws: crate::socket::ws::Ws<'_, BUF_SIZE>, ) { Timer::after_millis(500).await; let r: Result<(), ()> = try { diff --git a/src/apps/mod.rs b/src/apps/mod.rs index c65d724..996f1a2 100644 --- a/src/apps/mod.rs +++ b/src/apps/mod.rs @@ -21,12 +21,7 @@ pub trait App { fn accept_ws(&self, _path: &str) -> bool { false } - async fn handle_ws<'a, const BUF_SIZE: usize>( - &'a mut self, - _path: &str, - _ws: Ws<'a, BUF_SIZE>, - ) { - } + async fn handle_ws(&mut self, _path: &str, _ws: Ws<'_, BUF_SIZE>) {} } pub struct Content<'a>(pub Vec<&'a str, 8>); diff --git a/src/apps/ttt.rs b/src/apps/ttt.rs index ee15171..b4f3d4e 100644 --- a/src/apps/ttt.rs +++ b/src/apps/ttt.rs @@ -13,8 +13,6 @@ use crate::socket::{HttpRequestType, HttpResCode}; use super::App; -// bits [0; 8] : player zero board / bits [9; 17] : player one board / is_ended [18] / is_draw [19] / winner [20]: 0=blue 1=green / current_turn [21]: 0=blue 1=green - #[derive(Debug, Serialize, Clone, PartialEq, Eq)] struct Game { board: [Option; 9], diff --git a/src/socket.rs b/src/socket.rs index 7eb61d4..460dbd1 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -1,12 +1,12 @@ -use base64::{EncodeSliceError, prelude::*}; +use base64::prelude::*; +use core::fmt::Write; use core::str::from_utf8; -use core::{fmt::Write, str::FromStr}; use embassy_net::tcp::TcpSocket; use embassy_time::{Duration, Timer}; use embedded_io_async::Write as _; use heapless::{String, Vec}; use log::{info, warn}; -use pico_website::{unwrap, unwrap_opt}; +use pico_website::unwrap; use sha1::{Digest, Sha1}; use crate::apps::Content; @@ -17,24 +17,32 @@ pub mod ws; #[cfg(feature = "ttt")] #[embassy_executor::task(pool_size = 2)] pub async fn ttt_listen_task(stack: embassy_net::Stack<'static>, team: apps::ttt::Team, port: u16) { - listen_task(stack, apps::ttt::TttApp::new(team), port).await + listen_task::<32, 32, 256, 256>(stack, apps::ttt::TttApp::new(team), port).await } #[embassy_executor::task(pool_size = 2)] pub async fn index_listen_task(stack: embassy_net::Stack<'static>, port: u16) { - listen_task(stack, apps::index::IndexApp, port).await + listen_task::<64, 0, 1024, 1024>(stack, apps::index::IndexApp, port).await } #[cfg(feature = "chat")] #[embassy_executor::task(pool_size = 4)] pub async fn chat_listen_task(stack: embassy_net::Stack<'static>, id: u8, port: u16) { - listen_task(stack, apps::chat::ChatApp::new(id), port).await + listen_task::<64, 512, 512, 512>(stack, apps::chat::ChatApp::new(id), port).await } -pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps::App, port: u16) { - let mut rx_buffer = [0; 1024]; - let mut tx_buffer = [0; 2048]; - let mut buf = [0; 1024]; +pub async fn listen_task< + const PATH_LEN: usize, + const BUF_LEN: usize, + const RX_LEN: usize, + const TX_LEN: usize, +>( + stack: embassy_net::Stack<'static>, + mut app: impl apps::App, + port: u16, +) { + let mut rx_buffer = [0; RX_LEN]; + let mut tx_buffer = [0; TX_LEN]; let mut head_buf = Vec::::new(); loop { Timer::after_secs(0).await; @@ -54,31 +62,39 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps: ); let (mut rx, mut tx) = socket.split(); - let mut ws_path: Option> = None; + let mut buf = String::::new(); + let mut path = String::::new(); + let mut request_type = HttpRequestType::Get; + let mut is_ws = false; loop { Timer::after_secs(0).await; - let mut cont = &[]; - rx.read_with(|msg| { - let (headers, content) = match from_utf8(msg) { - Ok(b) => match b.split_once("\r\n\r\n") { - Some(t) => t, - None => (b, ""), - }, - Err(_) => { - warn!("Non utf8 http request"); + + match rx + .read_with(|msg| { + let (headers, content) = match from_utf8(msg) { + Ok(b) => match b.split_once("\r\n\r\n") { + Some(t) => t, + None => (b, ""), + }, + Err(_) => { + warn!("Non utf8 http request"); + return (0, Err(())); + } + }; + buf.clear(); + if let Err(_) = buf.push_str(content) { + warn!("Received content is bigger than maximum content!"); return (0, Err(())); } - }; - let mut hl = headers.lines(); - let (request_type, path) = match hl.next() { - None => { - warn!("Empty request"); - return (0, Err(())); - } - Some(l1) => { - let mut l1 = l1.split(' '); - ( - match l1.next() { + let mut hl = headers.lines(); + match hl.next() { + None => { + warn!("Empty request"); + return (0, Err(())); + } + Some(l1) => { + let mut l1 = l1.split(' '); + request_type = match l1.next() { Some("GET") => HttpRequestType::Get, Some("POST") => HttpRequestType::Post, Some(t) => { @@ -89,110 +105,119 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps: warn!("No request type"); return (0, Err(())); } - }, - match l1.next() { + }; + path.clear(); + if let Err(_) = path.push_str(match l1.next() { Some(path) => path, None => { warn!("No path"); return (0, Err(())); } - }, - ) - } - }; - let mut host = None; - let mut ws_handshake = false; - let mut ws_key = None; - for h in hl { - let Some((name, val)) = h.split_once(':') else { - continue; - }; - let name = name.trim(); - let val = val.trim(); - match (name, val) { - ("Host", _) => host = Some(val), - ("Upgrade", "websocket") => ws_handshake = true, - ("Sec-WebSocket-Key", _) => ws_key = Some(val), - _ => {} - } - } - let Some(host) = host else { - warn!("No host"); - return (0, Err(())); - }; - info!( - "Socket {}: {:?}{} request for {}{}", - app.socket_name(), - request_type, - if ws_handshake { " websocket" } else { "" }, - host, - path, - ); - head_buf.clear(); - let res_content: Result, core::fmt::Error> = try { - if ws_handshake { - if !app.accept_ws(path) { - warn!("No ws there!"); - write!( - &mut head_buf, - "{}\r\n\r\n", - Into::<&str>::into(HttpResCode::NotFound) - )?; - None - } else { - if path.len() > 16 { - warn!("Ws socket cannot have path longer than 16 chars!"); + }) { + warn!("Path is too big!"); return (0, Err(())); } - let Some(key) = ws_key else { - warn!("No ws key!"); - return (0, Err(())); - }; - let Ok(accept) = compute_ws_accept(key) else { - return (0, Err(())); - }; - write!( - &mut head_buf, - "{}\r\n\ + } + }; + let mut host = None; + + let mut ws_key = None; + for h in hl { + let Some((name, val)) = h.split_once(':') else { + continue; + }; + let name = name.trim(); + let val = val.trim(); + match (name, val) { + ("Host", _) => host = Some(val), + ("Upgrade", "websocket") => is_ws = true, + ("Sec-WebSocket-Key", _) => ws_key = Some(val), + _ => {} + } + } + let Some(host) = host else { + warn!("No host!"); + return (0, Err(())); + }; + info!( + "Socket {}: {:?}{} request for {}{}", + app.socket_name(), + request_type, + if is_ws { " websocket" } else { "" }, + host, + path, + ); + buf.clear(); + if is_ws { + let Some(key) = ws_key else { + warn!("No ws key!"); + return (0, Err(())); + }; + if let Err(_) = buf.push_str(key) { + warn!("Ws key is too long!"); + return (0, Err(())); + } + } else { + if let Err(_) = buf.push_str(content) { + warn!("Content is too long!"); + return (0, Err(())); + } + } + (msg.len(), Ok(())) + }) + .await + { + Ok(Ok(())) => {} + Ok(Err(())) => break, + Err(e) => { + warn!("Error while receiving : {:?}", e); + break; + } + }; + + head_buf.clear(); + let res_content: Result, core::fmt::Error> = try { + if is_ws { + if !app.accept_ws(&path) { + warn!("No ws there!"); + write!( + &mut head_buf, + "{}\r\n\r\n", + Into::<&str>::into(HttpResCode::NotFound) + )?; + None + } else { + let Ok(accept) = compute_ws_accept(&buf) else { + break; + }; + write!( + &mut head_buf, + "{}\r\n\ Upgrade: websocket\r\n\ Connection: Upgrade\r\n\ Sec-WebSocket-Accept: {}\r\n\r\n", - Into::<&str>::into(HttpResCode::SwitchingProtocols), - accept - )?; - None - } - } else { - let (code, res_type, res_content) = - app.handle_request(path, request_type, content).await; - - write!(&mut head_buf, "{}", Into::<&str>::into(code))?; - if let Some(ref c) = res_content { - write!( - &mut head_buf, - "\r\n\ - Content-Type: text/{}\r\n\ - Content-Length: {}\r\n", - res_type, - c.len() - )?; - } - write!(&mut head_buf, "\r\n\r\n")?; - res_content + Into::<&str>::into(HttpResCode::SwitchingProtocols), + accept + )?; + None } - }; - (msg.len(), Ok(())) - }) - .await; - let n = match socket.read(&mut buf).await { - Ok(0) => { - warn!("read EOF"); - break; - } - Ok(n) => n, - Err(e) => { - warn!("Socket {}: read error: {:?}", app.socket_name(), e); - break; + } else { + let (code, res_type, res_content) = + app.handle_request(&path, request_type, &buf).await; + + write!(&mut head_buf, "{}", Into::<&str>::into(code))?; + if let Some(ref c) = res_content { + write!( + &mut head_buf, + "\r\n\ + Content-Type: text/{}\r\n\ + Content-Length: {}\r\n", + res_type, + c.len() + )?; + } + write!(&mut head_buf, "\r\n\r\n")?; + res_content } }; @@ -205,10 +230,10 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps: }; let w: Result<(), embassy_net::tcp::Error> = try { - socket.write_all(&head_buf).await?; + tx.write_all(&head_buf).await?; if let Some(ref c) = res_content { for s in c.0.iter() { - socket.write_all(s.as_bytes()).await?; + tx.write_all(s.as_bytes()).await?; } } }; @@ -217,15 +242,19 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps: warn!("write error: {:?}", e); break; }; - - if ws_handshake { - ws_path = Some(unwrap(String::from_str(path)).await); - break; - } } - if let Some(path) = ws_path { - app.handle_ws(&path, Ws::new(&mut socket, &mut buf, app.socket_name())) - .await; + if is_ws { + let mut buf = buf.into_bytes(); + unwrap(buf.resize_default(BUF_LEN)).await; + app.handle_ws::( + &path, + Ws::new( + &mut socket, + &mut unwrap(buf.into_array()).await, + app.socket_name(), + ), + ) + .await; } } }