diff --git a/src/apps/chat.rs b/src/apps/chat.rs index eb9c3f4..c100e67 100644 --- a/src/apps/chat.rs +++ b/src/apps/chat.rs @@ -34,7 +34,6 @@ impl App for ChatApp { } async fn handle_request<'a>( &'a mut self, - _host: &str, path: &str, req_type: HttpRequestType, content: &str, diff --git a/src/apps/index.rs b/src/apps/index.rs index ad6d393..8c4c254 100644 --- a/src/apps/index.rs +++ b/src/apps/index.rs @@ -9,7 +9,6 @@ impl App for IndexApp { } async fn handle_request<'a>( &'a mut self, - _host: &str, path: &str, _req_type: HttpRequestType, _content: &str, diff --git a/src/apps/mod.rs b/src/apps/mod.rs index 8ff1aa4..80180e1 100644 --- a/src/apps/mod.rs +++ b/src/apps/mod.rs @@ -1,6 +1,6 @@ -use crate::socket::{HttpRequestType, HttpResCode}; +use crate::socket::{HttpRequestType, HttpResCode, ws::Ws}; -#[cfg(feature="chat")] +#[cfg(feature = "chat")] pub mod chat; pub mod index; #[cfg(feature = "ttt")] @@ -10,9 +10,19 @@ pub trait App { fn socket_name(&self) -> &'static str; async fn handle_request<'a>( &'a mut self, - host: &str, - path: &str, - req_type: HttpRequestType, - content: &str, - ) -> (HttpResCode, &'static str, &'a str); + _path: &str, + _req_type: HttpRequestType, + _content: &str, + ) -> (HttpResCode, &'static str, &'a str) { + (HttpResCode::NotFound, "", "") + } + fn accept_ws(&self, _path: &str) -> bool { + false + } + async fn handle_ws<'a, const BUF_SIZE: usize, const RES_HEAD_BUF_SIZE: usize>( + &'a mut self, + _path: &str, + _ws: Ws<'a, BUF_SIZE, RES_HEAD_BUF_SIZE>, + ) { + } } diff --git a/src/apps/ttt.rs b/src/apps/ttt.rs index 73d07cd..bc3a7b6 100644 --- a/src/apps/ttt.rs +++ b/src/apps/ttt.rs @@ -152,7 +152,6 @@ impl App for TttApp { } async fn handle_request<'a>( &'a mut self, - _host: &str, path: &str, _req_type: HttpRequestType, _content: &str, diff --git a/src/main.rs b/src/main.rs index 46bd2b8..5e609dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,12 +18,11 @@ use embassy_rp::peripherals::USB; use embassy_rp::peripherals::{DMA_CH0, PIO0}; use embassy_rp::pio::{InterruptHandler as PioInterruptHandler, Pio}; use embassy_rp::usb::{Driver, InterruptHandler as UsbInterruptHandler}; +use log::info; use pico_website::unwrap; use rand_core::RngCore; use static_cell::StaticCell; use {defmt_rtt as _, panic_probe as _}; -use log::info; -use embassy_time::Timer; #[cfg(feature = "dhcp")] mod dhcp; diff --git a/src/socket.rs b/src/socket.rs index 2f2c700..f223d54 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -1,7 +1,6 @@ use base64::{EncodeSliceError, prelude::*}; -use core::fmt::Write; use core::str::from_utf8; -use defmt::dbg; +use core::{fmt::Write, str::FromStr}; use embassy_net::tcp::TcpSocket; use embassy_time::{Duration, Timer}; use embedded_io_async::Write as _; @@ -9,7 +8,7 @@ use heapless::{String, Vec}; use log::{info, warn}; use sha1::{Digest, Sha1}; -use crate::apps; +use crate::{apps, socket::ws::Ws}; pub mod ws; @@ -38,8 +37,7 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps: let mut rx_buffer = [0; 1024]; let mut tx_buffer = [0; 2048]; let mut buf = [0; 1024]; - let mut res_head_buf = Vec::::new(); - + let mut head_buf = Vec::::new(); loop { Timer::after_secs(0).await; let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer); @@ -70,7 +68,7 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps: break; } }; - + head_buf.clear(); let (headers, content) = match from_utf8(&buf[..n]) { Ok(b) => match b.split_once("\r\n\r\n") { Some(t) => t, @@ -146,31 +144,44 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps: ); Timer::after_secs(0).await; - res_head_buf.clear(); + head_buf.clear(); let res_content: Result<&str, core::fmt::Error> = try { if ws_handshake { - let Some(key) = ws_key else { - warn!("No ws key!"); - break; - }; - let accept = match compute_ws_accept(key).await { - Ok(a) => a, - Err(e) => { - warn!("compute ws accept error : {:?}", e); + if !app.accept_ws(path) { + write!( + &mut head_buf, + "{}\r\n\r\n", + Into::<&str>::into(HttpResCode::NotFound) + )?; + "" + } else { + if path.len() > 16 { + warn!("Ws socket cannot have path longer than 16 chars!"); break; } - }; - write!( - &mut res_head_buf, - "{}\r\n\ - Upgrade: websocket\r\n\ - Connection: Upgrade\r\n\ - Sec-WebSocket-Accept: {}\r\n\r\n", - // Sec-WebSocket-Protocol: chat\r\n - Into::<&str>::into(HttpResCode::SwitchingProtocols), - accept - )?; - "" + let Some(key) = ws_key else { + warn!("No ws key!"); + break; + }; + let accept = match compute_ws_accept(key).await { + Ok(a) => a, + Err(e) => { + warn!("compute ws accept error : {:?}", e); + break; + } + }; + write!( + &mut head_buf, + "{}\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Accept: {}\r\n\r\n", + // Sec-WebSocket-Protocol: chat\r\n + Into::<&str>::into(HttpResCode::SwitchingProtocols), + accept + )?; + "" + } } else { let (code, res_type, res_content): (HttpResCode, &str, &str) = match path { "/htmx.js" => ( @@ -181,21 +192,21 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps: #[cfg(not(debug_assertions))] include_bytes!("../static/htmx.min.js"), ), - _ => app.handle_request(host, path, request_type, content).await, + _ => app.handle_request(path, request_type, content).await, }; - write!(&mut res_head_buf, "{}", Into::<&str>::into(code))?; + write!(&mut head_buf, "{}", Into::<&str>::into(code))?; if res_type.len() > 0 { write!( - &mut res_head_buf, + &mut head_buf, "\r\n\ - Content-Type: text/{}\r\n\ - Content-Length: {}\r\n", + Content-Type: text/{}\r\n\ + Content-Length: {}\r\n", res_type, res_content.len() )?; } - write!(&mut res_head_buf, "\r\n\r\n")?; + write!(&mut head_buf, "\r\n\r\n")?; res_content } }; @@ -208,9 +219,9 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps: } }; - info!("\n{}\n", from_utf8(&res_head_buf).unwrap()); + info!("\n{}\n", from_utf8(&head_buf).unwrap()); - match socket.write_all(&res_head_buf).await { + match socket.write_all(&head_buf).await { Ok(()) => {} Err(e) => { warn!("write error: {:?}", e); @@ -224,6 +235,16 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps: break; } }; + + if ws_handshake { + let path: String<16> = String::from_str(path).unwrap(); + app.handle_ws( + &path, + Ws::new(&mut socket, &mut buf, &mut head_buf, app.socket_name()), + ) + .await; + break; + } } } } diff --git a/src/socket/ws.rs b/src/socket/ws.rs index e9e75b6..c8119b2 100644 --- a/src/socket/ws.rs +++ b/src/socket/ws.rs @@ -1,6 +1,187 @@ -// pub struct Ws { +use core::str::from_utf8; -// } -// impl Ws { -// pub fn handshake -// } +use embassy_net::tcp::{TcpReader, TcpSocket, TcpWriter}; +use embassy_time::Instant; +use embedded_io_async::{ErrorType, ReadReady, Write}; +use heapless::Vec; +use log::{info, warn}; + +#[derive(Clone, Copy)] +pub enum WsMsg<'a> { + Ping(&'a [u8]), + Pong(&'a [u8]), + Text(&'a str), + Bytes(&'a [u8]), + Unknown(u8, &'a [u8]), +} +impl WsMsg<'_> { + pub const fn len(&self) -> usize { + self.as_bytes().len() + } + pub const fn as_bytes(&self) -> &[u8] { + match self { + WsMsg::Text(t) => t.as_bytes(), + WsMsg::Bytes(b) | WsMsg::Pong(b) | WsMsg::Ping(b) | WsMsg::Unknown(_, b) => b, + } + } + pub const fn code(&self) -> u8 { + match self { + WsMsg::Text(_) => 1, + WsMsg::Bytes(_) => 2, + WsMsg::Ping(_) => 9, + WsMsg::Pong(_) => 10, + WsMsg::Unknown(c, _) => *c, + } + } +} + +struct WsRx<'a, const BUF_SIZE: usize> { + socket: TcpReader<'a>, + buf: &'a mut [u8; BUF_SIZE], + last_msg: Instant, +} +struct WsTx<'a, const HEAD_BUF_SIZE: usize> { + socket: TcpWriter<'a>, + head_buf: &'a mut Vec, +} +impl<'a, const HEAD_BUF_SIZE: usize> WsTx<'a, HEAD_BUF_SIZE> { + pub async fn send<'m>(&mut self, msg: WsMsg<'m>) -> Result<(), ()> { + self.head_buf.clear(); + self.head_buf.push(0b1000_0000 | msg.code()).unwrap(); + if msg.len() < 126 { + self.head_buf.push(msg.len() as u8).unwrap(); + } else { + self.head_buf.push(0b0111_1110).unwrap(); + self.head_buf + .extend_from_slice(&(msg.len() as u16).to_le_bytes()) + .unwrap(); + self.head_buf.extend_from_slice(msg.as_bytes()).unwrap(); + } + let w: Result<(), as ErrorType>::Error> = try { + self.socket.write_all(&self.head_buf).await?; + self.socket.write_all(msg.as_bytes()).await?; + }; + w.map_err(|e| { + warn!("write error: {:?}", e); + () + }) + } +} + +pub struct Ws<'a, const BUF_SIZE: usize = 1024, const RES_HEAD_BUF_SIZE: usize = 256> { + rx: WsRx<'a, BUF_SIZE>, + tx: WsTx<'a, RES_HEAD_BUF_SIZE>, + name: &'a str, +} +impl<'a, const BUF_SIZE: usize, const HEAD_BUF_SIZE: usize> Ws<'a, BUF_SIZE, HEAD_BUF_SIZE> { + pub fn new( + socket: &'a mut TcpSocket, + buf: &'a mut [u8; BUF_SIZE], + head_buf: &'a mut Vec, + name: &'a str, + ) -> Self { + let (rx, tx) = socket.split(); + Self { + rx: WsRx { + socket: rx, + buf, + last_msg: Instant::MIN, + }, + tx: WsTx { + socket: tx, + head_buf, + }, + name, + } + } + // Do this often to respond to pings + async fn rcv(&mut self) -> Result, ()> { + if !self.rx.socket.read_ready().unwrap() { + return Ok(None); + } + let n = match self.rx.socket.read(self.rx.buf).await { + Ok(0) => { + warn!("read EOF"); + return Err(()); + } + Ok(n) => n, + Err(e) => { + warn!("Socket {}: read error: {:?}", self.name, e); + return Err(()); + } + }; + if self.rx.buf[0] & 0b1000_0000 == 0 { + warn!("Fragmented ws messages are not supported!"); + return Err(()); + } + if self.rx.buf[0] & 0b0111_0000 != 0 { + warn!( + "Reserved ws bits are set : {}", + (self.rx.buf[0] >> 4) & 0b0111 + ); + return Err(()); + } + let (length, n_after_length) = match self.rx.buf[1] & 0b0111_1111 { + 126 => ( + u64::from_le_bytes([0, 0, 0, 0, 0, 0, self.rx.buf[2], self.rx.buf[3]]), + 4, + ), + 127 => ( + u64::from_le_bytes(self.rx.buf[2..10].try_into().unwrap()), + 10, + ), + l => (l as u64, 2), + }; + if length > 512 { + warn!("ws payload bigger than 512!"); + return Err(()); + } + + let content = if self.rx.buf[1] & 0b1000_0000 != 0 { + // masked message + if n_after_length + 4 + length as usize > n { + warn!("ws payload smaller than length"); + return Err(()); + } + let mask_key: [u8; 4] = self.rx.buf[n_after_length..n_after_length + 4] + .try_into() + .unwrap(); + for (i, x) in self.rx.buf[n_after_length + 4..n_after_length + 4 + length as usize] + .iter_mut() + .enumerate() + { + *x ^= mask_key[i & 0xff]; + } + &self.rx.buf[n_after_length + 4..n_after_length + 4 + length as usize] + } else { + if n_after_length + length as usize > n { + warn!("ws payload smaller than length"); + return Err(()); + } + &self.rx.buf[n_after_length..n_after_length + length as usize] + }; + self.rx.last_msg = Instant::now(); + match self.rx.buf[0] & 0b0000_1111 { + // Text message + 1 => { + let content = from_utf8(&content).map_err(|_| ())?; + info!("Received text : {:?}", content); + Ok(Some(WsMsg::Text(content))) + } + // Ping + 9 => { + self.tx.send(WsMsg::Pong(&content)).await?; + Ok(Some(WsMsg::Ping(&content))) + } + // Pong + 10 => Ok(Some(WsMsg::Pong(&content))), + c => { + info!("Unknown ws op code (ignoring) : {}", c); + Ok(Some(WsMsg::Unknown(c, &content))) + } + } + } + pub async fn send(&mut self, msg: WsMsg<'a>) -> Result<(), ()> { + self.tx.send(msg).await + } +}