#![cfg(feature = "server")]
//! TLS ping/pong game server.
//!
//! Peers connect over TLS and send messages; `"ping"` is answered with a pong carrying the
//! running ping count, anything else with a prompt for the magic word. After enough pings
//! (see [`Config`]'s `win_after` / `indefinitely`) the peer is told it won and disconnected.

use std::{
    fs::File,
    net::SocketAddr,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    time::Duration,
    usize,
};

use libpt::log::{debug, error, info, trace, warn};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls_pemfile::{certs, private_key};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    net::{TcpListener, TcpStream},
    time::{self},
};
use tokio_rustls::{rustls, TlsAcceptor};

use crate::common::{conf::Config, decode};

pub mod errors;
use errors::*;

/// Size of the per-connection receive buffer; a single `read` handles at most this many bytes.
const BUF_SIZE: usize = 512;

/// A TLS server that plays the ping/pong game with every connected peer.
pub struct Server {
    /// Runtime configuration (listen address, TLS material, game rules, reply delay).
    cfg: Config,
    /// Listening socket that incoming TCP connections are accepted from.
    server: TcpListener,
    /// Number of peers whose TLS handshake succeeded and whose session is still running.
    num_peers: AtomicUsize,
    /// TLS acceptor built from the configured certificate chain and private key.
    acceptor: TlsAcceptor,
}

impl Server {
    /// Builds a server from `cfg`: loads the certificate chain and private key, creates the
    /// TLS acceptor, and binds the listener to `cfg.addr`.
    ///
    /// # Errors
    ///
    /// Fails if the certificates or the key cannot be loaded (not configured, unreadable, or
    /// the PEM file contains no usable entry), if rustls rejects them, or if binding fails.
    pub async fn build(cfg: Config) -> anyhow::Result<Self> {
        let certs = Self::load_certs(&cfg)?;
        trace!("loaded certs: {:?}", certs);
        // Deliberately do not log the key itself: it is secret material.
        let key = Self::load_key(&cfg)?;
        trace!("loaded key");

        let tls_config = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)?;
        let acceptor = TlsAcceptor::from(Arc::new(tls_config));

        let server = TcpListener::bind(cfg.addr).await?;

        Ok(Server {
            cfg,
            server,
            num_peers: AtomicUsize::new(0),
            acceptor,
        })
    }

    /// Runs the accept loop.
    ///
    /// A background task logs the current peer count every 5 seconds. Every accepted TCP
    /// connection gets its own task that performs the TLS handshake and then runs
    /// [`Server::handle_stream`]. Per-connection errors are logged, never propagated; the
    /// accept loop has no exit, so this does not return under normal operation.
    pub async fn run(self) -> anyhow::Result<()> {
        let rc_self = Arc::new(self);

        let status_self = Arc::clone(&rc_self);
        tokio::spawn(async move {
            let mut interval = time::interval(Duration::from_millis(5000));
            loop {
                interval.tick().await;
                info!(
                    "status: {} peers",
                    status_self.num_peers.load(Ordering::Relaxed)
                );
            }
        });

        loop {
            let (stream, addr) = match rc_self.server.accept().await {
                Ok(s) => s,
                Err(err) => {
                    warn!("could not accept tcp stream: {err:?}");
                    continue;
                }
            };
            let ref_self = Arc::clone(&rc_self);
            // NOTE: we can only start the task now. If we started it before accepting the
            // connection (so that the task itself accepts it), we would spawn tasks endlessly
            // in this loop.
            tokio::spawn(async move {
                let stream = match ref_self.acceptor.accept(stream).await {
                    Ok(s) => s,
                    Err(err) => {
                        warn!("could not accept tls stream: {err}");
                        return;
                    }
                };
                ref_self.peer_add(1);
                if let Err(err) = ref_self.handle_stream(stream, addr).await {
                    match err {
                        ServerError::Timeout(_) => {
                            debug!("stream {:?} timed out", addr)
                        }
                        _ => {
                            warn!("error while handling stream: {:?}", err)
                        }
                    }
                }
                ref_self.peer_sub(1);
            });
        }
    }

    /// Reads the private key from the PEM file at `cfg.key`, using whatever
    /// `rustls_pemfile::private_key` returns first.
    ///
    /// # Errors
    ///
    /// `InvalidInput` if no key path is configured or the file contains no private key;
    /// otherwise any I/O or PEM parsing error raised while reading the file.
    fn load_key(cfg: &Config) -> std::io::Result<PrivateKeyDer<'static>> {
        let Some(path) = &cfg.key else {
            error!("the server needs a key!");
            return Err(std::io::ErrorKind::InvalidInput.into());
        };
        match private_key(&mut std::io::BufReader::new(File::open(path)?))? {
            Some(key) => Ok(key),
            None => {
                error!("no private key found in provided file {:?}", cfg.key);
                Err(std::io::ErrorKind::InvalidInput.into())
            }
        }
    }

    /// Reads every certificate from the PEM file at `cfg.certs`.
    ///
    /// # Errors
    ///
    /// `InvalidInput` if no certificate path is configured or the file holds no certificate;
    /// otherwise any I/O or PEM parsing error raised while reading the file.
    fn load_certs(cfg: &Config) -> std::io::Result<Vec<CertificateDer<'static>>> {
        let Some(path) = &cfg.certs else {
            error!("the server needs at least one certificate!");
            return Err(std::io::ErrorKind::InvalidInput.into());
        };
        match certs(&mut std::io::BufReader::new(File::open(path)?))
            .collect::<std::io::Result<Vec<CertificateDer<'static>>>>()
        {
            Ok(v) if !v.is_empty() => Ok(v),
            Ok(_) => {
                error!("no certs found in provided file {:?}", cfg.certs);
                Err(std::io::ErrorKind::InvalidInput.into())
            }
            Err(err) => {
                error!("could not load certs: {err:?}");
                Err(err)
            }
        }
    }

    /// Increases the live-peer counter by `v`.
    ///
    /// A single read-modify-write (`fetch_add`) is required because many connection tasks
    /// update the counter concurrently; a separate load + store would lose updates.
    #[inline]
    fn peer_add(&self, v: usize) {
        self.num_peers.fetch_add(v, Ordering::Relaxed);
    }

    /// Decreases the live-peer counter by `v` (atomic counterpart of [`Server::peer_add`]).
    #[inline]
    fn peer_sub(&self, v: usize) {
        self.num_peers.fetch_sub(v, Ordering::Relaxed);
    }

    /// Plays the ping game with one peer until it disconnects or wins.
    ///
    /// Every received message is decoded and answered: `"ping"` gets `"pong (<count in hex>)"`,
    /// anything else gets `"what is the magic word?"`. Once the peer has sent more than
    /// `cfg.win_after` pings (unless `cfg.indefinitely` is set), or the saturating counter has
    /// reached `usize::MAX`, the peer receives `"You win!"` and the TLS session is shut down.
    /// After each reply the task waits `cfg.delay` before reading again.
    ///
    /// # Errors
    ///
    /// Propagates read/write/flush/shutdown errors and failures from `decode`.
    async fn handle_stream(
        &self,
        mut stream: tokio_rustls::server::TlsStream<TcpStream>,
        addr: SocketAddr,
    ) -> Result<()> {
        let mut buf = [0; BUF_SIZE];
        let mut pings: usize = 0;
        debug!("new peer: {:?}", addr);
        loop {
            let n = stream.read(&mut buf).await?;
            if n == 0 {
                // EOF: the peer closed the connection.
                break;
            }
            // `decode` is given the whole buffer, so clear whatever an earlier, longer message
            // left behind the bytes just read; otherwise "ping" after "hello" decodes as "pingo".
            buf[n..].fill(0);
            let request = decode(&buf)?;
            debug!(pings, "< ({})\n\"{}\"", addr, request);
            if request == "ping" {
                pings = pings.saturating_add(1);
                // The counter saturates, so reaching usize::MAX ends the game even when
                // `indefinitely` is set.
                if (pings > self.cfg.win_after && !self.cfg.indefinitely) || pings == usize::MAX
                {
                    stream.write_all(b"You win!").await?;
                    debug!(pings, "> ({})\n\"{}\"", addr, "You win!");
                    info!("{} won!", addr);
                    stream.flush().await?;
                    stream.shutdown().await?;
                    break;
                }
                let response = format!("pong ({:x})", pings);
                stream.write_all(response.as_bytes()).await?;
                debug!(pings, "> ({})\n\"{}\"", addr, response);
                stream.flush().await?;
            } else {
                stream.write_all(b"what is the magic word?").await?;
                debug!(pings, "> ({})\n\"{}\"", addr, "what is the magic word?");
                stream.flush().await?;
            }
            // Wait so that we don't spam the client. This must be the async sleep:
            // `std::thread::sleep` would block the tokio worker thread and stall every other
            // connection scheduled on it.
            time::sleep(self.cfg.delay).await;
        }
        debug!("disconnected peer: {:?}", addr);
        Ok(())
    }
}