This commit is contained in:
Arkitu 2025-07-15 16:47:16 +02:00
parent b50300fbbb
commit f136f55266
7 changed files with 260 additions and 52 deletions

View File

@ -34,7 +34,6 @@ impl App for ChatApp {
} }
async fn handle_request<'a>( async fn handle_request<'a>(
&'a mut self, &'a mut self,
_host: &str,
path: &str, path: &str,
req_type: HttpRequestType, req_type: HttpRequestType,
content: &str, content: &str,

View File

@ -9,7 +9,6 @@ impl App for IndexApp {
} }
async fn handle_request<'a>( async fn handle_request<'a>(
&'a mut self, &'a mut self,
_host: &str,
path: &str, path: &str,
_req_type: HttpRequestType, _req_type: HttpRequestType,
_content: &str, _content: &str,

View File

@ -1,6 +1,6 @@
use crate::socket::{HttpRequestType, HttpResCode}; use crate::socket::{HttpRequestType, HttpResCode, ws::Ws};
#[cfg(feature="chat")] #[cfg(feature = "chat")]
pub mod chat; pub mod chat;
pub mod index; pub mod index;
#[cfg(feature = "ttt")] #[cfg(feature = "ttt")]
@ -10,9 +10,19 @@ pub trait App {
fn socket_name(&self) -> &'static str; fn socket_name(&self) -> &'static str;
async fn handle_request<'a>( async fn handle_request<'a>(
&'a mut self, &'a mut self,
host: &str, _path: &str,
path: &str, _req_type: HttpRequestType,
req_type: HttpRequestType, _content: &str,
content: &str, ) -> (HttpResCode, &'static str, &'a str) {
) -> (HttpResCode, &'static str, &'a str); (HttpResCode::NotFound, "", "")
}
fn accept_ws(&self, _path: &str) -> bool {
false
}
async fn handle_ws<'a, const BUF_SIZE: usize, const RES_HEAD_BUF_SIZE: usize>(
&'a mut self,
_path: &str,
_ws: Ws<'a, BUF_SIZE, RES_HEAD_BUF_SIZE>,
) {
}
} }

View File

@ -152,7 +152,6 @@ impl App for TttApp {
} }
async fn handle_request<'a>( async fn handle_request<'a>(
&'a mut self, &'a mut self,
_host: &str,
path: &str, path: &str,
_req_type: HttpRequestType, _req_type: HttpRequestType,
_content: &str, _content: &str,

View File

@ -18,12 +18,11 @@ use embassy_rp::peripherals::USB;
use embassy_rp::peripherals::{DMA_CH0, PIO0}; use embassy_rp::peripherals::{DMA_CH0, PIO0};
use embassy_rp::pio::{InterruptHandler as PioInterruptHandler, Pio}; use embassy_rp::pio::{InterruptHandler as PioInterruptHandler, Pio};
use embassy_rp::usb::{Driver, InterruptHandler as UsbInterruptHandler}; use embassy_rp::usb::{Driver, InterruptHandler as UsbInterruptHandler};
use log::info;
use pico_website::unwrap; use pico_website::unwrap;
use rand_core::RngCore; use rand_core::RngCore;
use static_cell::StaticCell; use static_cell::StaticCell;
use {defmt_rtt as _, panic_probe as _}; use {defmt_rtt as _, panic_probe as _};
use log::info;
use embassy_time::Timer;
#[cfg(feature = "dhcp")] #[cfg(feature = "dhcp")]
mod dhcp; mod dhcp;

View File

@ -1,7 +1,6 @@
use base64::{EncodeSliceError, prelude::*}; use base64::{EncodeSliceError, prelude::*};
use core::fmt::Write;
use core::str::from_utf8; use core::str::from_utf8;
use defmt::dbg; use core::{fmt::Write, str::FromStr};
use embassy_net::tcp::TcpSocket; use embassy_net::tcp::TcpSocket;
use embassy_time::{Duration, Timer}; use embassy_time::{Duration, Timer};
use embedded_io_async::Write as _; use embedded_io_async::Write as _;
@ -9,7 +8,7 @@ use heapless::{String, Vec};
use log::{info, warn}; use log::{info, warn};
use sha1::{Digest, Sha1}; use sha1::{Digest, Sha1};
use crate::apps; use crate::{apps, socket::ws::Ws};
pub mod ws; pub mod ws;
@ -38,8 +37,7 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps:
let mut rx_buffer = [0; 1024]; let mut rx_buffer = [0; 1024];
let mut tx_buffer = [0; 2048]; let mut tx_buffer = [0; 2048];
let mut buf = [0; 1024]; let mut buf = [0; 1024];
let mut res_head_buf = Vec::<u8, 256>::new(); let mut head_buf = Vec::<u8, 256>::new();
loop { loop {
Timer::after_secs(0).await; Timer::after_secs(0).await;
let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer); let mut socket = TcpSocket::new(stack, &mut rx_buffer, &mut tx_buffer);
@ -70,7 +68,7 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps:
break; break;
} }
}; };
head_buf.clear();
let (headers, content) = match from_utf8(&buf[..n]) { let (headers, content) = match from_utf8(&buf[..n]) {
Ok(b) => match b.split_once("\r\n\r\n") { Ok(b) => match b.split_once("\r\n\r\n") {
Some(t) => t, Some(t) => t,
@ -146,31 +144,44 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps:
); );
Timer::after_secs(0).await; Timer::after_secs(0).await;
res_head_buf.clear(); head_buf.clear();
let res_content: Result<&str, core::fmt::Error> = try { let res_content: Result<&str, core::fmt::Error> = try {
if ws_handshake { if ws_handshake {
let Some(key) = ws_key else { if !app.accept_ws(path) {
warn!("No ws key!"); write!(
break; &mut head_buf,
}; "{}\r\n\r\n",
let accept = match compute_ws_accept(key).await { Into::<&str>::into(HttpResCode::NotFound)
Ok(a) => a, )?;
Err(e) => { ""
warn!("compute ws accept error : {:?}", e); } else {
if path.len() > 16 {
warn!("Ws socket cannot have path longer than 16 chars!");
break; break;
} }
}; let Some(key) = ws_key else {
write!( warn!("No ws key!");
&mut res_head_buf, break;
"{}\r\n\ };
Upgrade: websocket\r\n\ let accept = match compute_ws_accept(key).await {
Connection: Upgrade\r\n\ Ok(a) => a,
Sec-WebSocket-Accept: {}\r\n\r\n", Err(e) => {
// Sec-WebSocket-Protocol: chat\r\n warn!("compute ws accept error : {:?}", e);
Into::<&str>::into(HttpResCode::SwitchingProtocols), break;
accept }
)?; };
"" write!(
&mut head_buf,
"{}\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: {}\r\n\r\n",
// Sec-WebSocket-Protocol: chat\r\n
Into::<&str>::into(HttpResCode::SwitchingProtocols),
accept
)?;
""
}
} else { } else {
let (code, res_type, res_content): (HttpResCode, &str, &str) = match path { let (code, res_type, res_content): (HttpResCode, &str, &str) = match path {
"/htmx.js" => ( "/htmx.js" => (
@ -181,21 +192,21 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps:
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]
include_bytes!("../static/htmx.min.js"), include_bytes!("../static/htmx.min.js"),
), ),
_ => app.handle_request(host, path, request_type, content).await, _ => app.handle_request(path, request_type, content).await,
}; };
write!(&mut res_head_buf, "{}", Into::<&str>::into(code))?; write!(&mut head_buf, "{}", Into::<&str>::into(code))?;
if res_type.len() > 0 { if res_type.len() > 0 {
write!( write!(
&mut res_head_buf, &mut head_buf,
"\r\n\ "\r\n\
Content-Type: text/{}\r\n\ Content-Type: text/{}\r\n\
Content-Length: {}\r\n", Content-Length: {}\r\n",
res_type, res_type,
res_content.len() res_content.len()
)?; )?;
} }
write!(&mut res_head_buf, "\r\n\r\n")?; write!(&mut head_buf, "\r\n\r\n")?;
res_content res_content
} }
}; };
@ -208,9 +219,9 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps:
} }
}; };
info!("\n{}\n", from_utf8(&res_head_buf).unwrap()); info!("\n{}\n", from_utf8(&head_buf).unwrap());
match socket.write_all(&res_head_buf).await { match socket.write_all(&head_buf).await {
Ok(()) => {} Ok(()) => {}
Err(e) => { Err(e) => {
warn!("write error: {:?}", e); warn!("write error: {:?}", e);
@ -224,6 +235,16 @@ pub async fn listen_task(stack: embassy_net::Stack<'static>, mut app: impl apps:
break; break;
} }
}; };
if ws_handshake {
let path: String<16> = String::from_str(path).unwrap();
app.handle_ws(
&path,
Ws::new(&mut socket, &mut buf, &mut head_buf, app.socket_name()),
)
.await;
break;
}
} }
} }
} }

View File

@ -1,6 +1,187 @@
// pub struct Ws { use core::str::from_utf8;
// } use embassy_net::tcp::{TcpReader, TcpSocket, TcpWriter};
// impl Ws { use embassy_time::Instant;
// pub fn handshake use embedded_io_async::{ErrorType, ReadReady, Write};
// } use heapless::Vec;
use log::{info, warn};
#[derive(Clone, Copy)]
pub enum WsMsg<'a> {
Ping(&'a [u8]),
Pong(&'a [u8]),
Text(&'a str),
Bytes(&'a [u8]),
Unknown(u8, &'a [u8]),
}
impl WsMsg<'_> {
pub const fn len(&self) -> usize {
self.as_bytes().len()
}
pub const fn as_bytes(&self) -> &[u8] {
match self {
WsMsg::Text(t) => t.as_bytes(),
WsMsg::Bytes(b) | WsMsg::Pong(b) | WsMsg::Ping(b) | WsMsg::Unknown(_, b) => b,
}
}
pub const fn code(&self) -> u8 {
match self {
WsMsg::Text(_) => 1,
WsMsg::Bytes(_) => 2,
WsMsg::Ping(_) => 9,
WsMsg::Pong(_) => 10,
WsMsg::Unknown(c, _) => *c,
}
}
}
struct WsRx<'a, const BUF_SIZE: usize> {
socket: TcpReader<'a>,
buf: &'a mut [u8; BUF_SIZE],
last_msg: Instant,
}
struct WsTx<'a, const HEAD_BUF_SIZE: usize> {
socket: TcpWriter<'a>,
head_buf: &'a mut Vec<u8, HEAD_BUF_SIZE>,
}
impl<'a, const HEAD_BUF_SIZE: usize> WsTx<'a, HEAD_BUF_SIZE> {
pub async fn send<'m>(&mut self, msg: WsMsg<'m>) -> Result<(), ()> {
self.head_buf.clear();
self.head_buf.push(0b1000_0000 | msg.code()).unwrap();
if msg.len() < 126 {
self.head_buf.push(msg.len() as u8).unwrap();
} else {
self.head_buf.push(0b0111_1110).unwrap();
self.head_buf
.extend_from_slice(&(msg.len() as u16).to_le_bytes())
.unwrap();
self.head_buf.extend_from_slice(msg.as_bytes()).unwrap();
}
let w: Result<(), <TcpSocket<'_> as ErrorType>::Error> = try {
self.socket.write_all(&self.head_buf).await?;
self.socket.write_all(msg.as_bytes()).await?;
};
w.map_err(|e| {
warn!("write error: {:?}", e);
()
})
}
}
pub struct Ws<'a, const BUF_SIZE: usize = 1024, const RES_HEAD_BUF_SIZE: usize = 256> {
rx: WsRx<'a, BUF_SIZE>,
tx: WsTx<'a, RES_HEAD_BUF_SIZE>,
name: &'a str,
}
impl<'a, const BUF_SIZE: usize, const HEAD_BUF_SIZE: usize> Ws<'a, BUF_SIZE, HEAD_BUF_SIZE> {
pub fn new(
socket: &'a mut TcpSocket,
buf: &'a mut [u8; BUF_SIZE],
head_buf: &'a mut Vec<u8, HEAD_BUF_SIZE>,
name: &'a str,
) -> Self {
let (rx, tx) = socket.split();
Self {
rx: WsRx {
socket: rx,
buf,
last_msg: Instant::MIN,
},
tx: WsTx {
socket: tx,
head_buf,
},
name,
}
}
// Do this often to respond to pings
async fn rcv(&mut self) -> Result<Option<WsMsg>, ()> {
if !self.rx.socket.read_ready().unwrap() {
return Ok(None);
}
let n = match self.rx.socket.read(self.rx.buf).await {
Ok(0) => {
warn!("read EOF");
return Err(());
}
Ok(n) => n,
Err(e) => {
warn!("Socket {}: read error: {:?}", self.name, e);
return Err(());
}
};
if self.rx.buf[0] & 0b1000_0000 == 0 {
warn!("Fragmented ws messages are not supported!");
return Err(());
}
if self.rx.buf[0] & 0b0111_0000 != 0 {
warn!(
"Reserved ws bits are set : {}",
(self.rx.buf[0] >> 4) & 0b0111
);
return Err(());
}
let (length, n_after_length) = match self.rx.buf[1] & 0b0111_1111 {
126 => (
u64::from_le_bytes([0, 0, 0, 0, 0, 0, self.rx.buf[2], self.rx.buf[3]]),
4,
),
127 => (
u64::from_le_bytes(self.rx.buf[2..10].try_into().unwrap()),
10,
),
l => (l as u64, 2),
};
if length > 512 {
warn!("ws payload bigger than 512!");
return Err(());
}
let content = if self.rx.buf[1] & 0b1000_0000 != 0 {
// masked message
if n_after_length + 4 + length as usize > n {
warn!("ws payload smaller than length");
return Err(());
}
let mask_key: [u8; 4] = self.rx.buf[n_after_length..n_after_length + 4]
.try_into()
.unwrap();
for (i, x) in self.rx.buf[n_after_length + 4..n_after_length + 4 + length as usize]
.iter_mut()
.enumerate()
{
*x ^= mask_key[i & 0xff];
}
&self.rx.buf[n_after_length + 4..n_after_length + 4 + length as usize]
} else {
if n_after_length + length as usize > n {
warn!("ws payload smaller than length");
return Err(());
}
&self.rx.buf[n_after_length..n_after_length + length as usize]
};
self.rx.last_msg = Instant::now();
match self.rx.buf[0] & 0b0000_1111 {
// Text message
1 => {
let content = from_utf8(&content).map_err(|_| ())?;
info!("Received text : {:?}", content);
Ok(Some(WsMsg::Text(content)))
}
// Ping
9 => {
self.tx.send(WsMsg::Pong(&content)).await?;
Ok(Some(WsMsg::Ping(&content)))
}
// Pong
10 => Ok(Some(WsMsg::Pong(&content))),
c => {
info!("Unknown ws op code (ignoring) : {}", c);
Ok(Some(WsMsg::Unknown(c, &content)))
}
}
}
pub async fn send(&mut self, msg: WsMsg<'a>) -> Result<(), ()> {
self.tx.send(msg).await
}
}