diff --git a/Cargo.toml b/Cargo.toml index 3ea72ae..05cc371 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,9 +18,11 @@ anyhow = "1.0.79" clap = "4.4.18" clap-num = "1.0.2" clap-verbosity-flag = "2.1.2" -libpt = { version = "0.3.10", features = ["net"] } +libpt = { version = "0.3.11", features = ["net"] } thiserror = "1.0.56" tokio = { version = "1.35.1", features = ["net", "rt", "macros"] } +rustls-pemfile = "2.0.0" +tokio-rustls = "0.25.0" [features] default = ["server"] diff --git a/spammer/src/main.rs b/spammer/src/main.rs index fb742e5..a95f16e 100644 --- a/spammer/src/main.rs +++ b/spammer/src/main.rs @@ -1,6 +1,6 @@ use threadpool::ThreadPool; const MAX: usize = 20; -use std::process::{exit, Command}; +use std::process::Command; fn main() { let pool = ThreadPool::new(MAX); diff --git a/src/common/args.rs b/src/common/args.rs index 9a53635..970d848 100644 --- a/src/common/args.rs +++ b/src/common/args.rs @@ -1,10 +1,10 @@ +use std::path::PathBuf; + use libpt::log::{Level, Logger}; use clap::Parser; use clap_verbosity_flag::{InfoLevel, Verbosity}; -use crate::common::conf::Mode; - /// short about section displayed in help const ABOUT_ROOT: &'static str = r##" Let your hosts play ping pong over the network @@ -46,16 +46,15 @@ pub(crate) struct Cli { #[arg(short, long, default_value_t = false)] pub(crate) server: bool, - // how much threads the server should use - #[cfg(feature = "server")] - #[arg(short, long, default_value_t = 4)] - pub(crate) threads: usize, - - #[arg(short, long, default_value_t = Mode::Tcp, ignore_case = true)] - pub(crate) mode: Mode, - /// Address of the server pub(crate) addr: std::net::SocketAddr, + + #[cfg(feature = "server")] + #[arg(short, long)] + pub key: PathBuf, + #[cfg(feature = "server")] + #[arg(short, long)] + pub certs: PathBuf, } impl Cli { @@ -72,7 +71,7 @@ impl Cli { } }; if cli.meta { - Logger::init(None, Some(ll)).expect("could not initialize Logger"); + Logger::init(None, Some(ll), true).expect("could not initialize Logger"); } else { // less verbose version Logger::init_mini(Some(ll)).expect("could not initialize Logger"); diff --git a/src/common/conf.rs b/src/common/conf.rs index af19358..edaef7f 100644 --- a/src/common/conf.rs +++ b/src/common/conf.rs @@ -1,68 +1,30 @@ use crate::common::args::Cli; use clap::ValueEnum; -use std::{fmt::Display, time::Duration}; +use std::{fmt::Display, path::PathBuf, time::Duration}; const DEFAULT_TIMEOUT_LEN: u64 = 5000; // ms const DEFAULT_DELAY_LEN: u64 = 500; // ms const DEFAULT_WIN_AFTER: usize = 20; -#[derive(Debug, Clone, Copy)] -pub enum Mode { - Tcp, - Tls, -} - -impl ValueEnum for Mode { - fn to_possible_value(&self) -> Option { - Some(match self { - Self::Tcp => clap::builder::PossibleValue::new("tcp"), - Self::Tls => clap::builder::PossibleValue::new("tls"), - }) - } - fn value_variants<'a>() -> &'a [Self] { - &[Self::Tcp] - } - fn from_str(input: &str, ignore_case: bool) -> Result { - let comp: String = if ignore_case { - input.to_lowercase() - } else { - input.to_string() - }; - match comp.as_str() { - "tcp" => return Ok(Self::Tcp), - _ => return Err(format!("\"{input}\" is not a valid mode")), - } - } -} - -impl Display for Mode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let repr: String = match self { - Self::Tcp => format!("tcp"), - Self::Tls => format!("tls"), - }; - write!(f, "{}", repr) - } -} - +#[derive(Clone)] pub struct Config { pub addr: std::net::SocketAddr, - pub mode: Mode, - pub threads: usize, pub timeout: Duration, pub delay: Duration, pub win_after: usize, + pub key: PathBuf, + pub certs: PathBuf, } impl Config { pub fn new(cli: &Cli) -> Self { Config { addr: cli.addr.clone(), - mode: cli.mode.clone(), - threads: cli.threads, timeout: Duration::from_millis(DEFAULT_TIMEOUT_LEN), delay: Duration::from_millis(DEFAULT_DELAY_LEN), win_after: DEFAULT_WIN_AFTER, + key: cli.key.clone(), + certs: cli.certs.clone(), } } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 9fd79ba..aad6334 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,39 +1,48 @@ #![cfg(feature = "server")] use std::{ - ops::Add, - sync::{atomic::AtomicUsize, Arc}, - time::Duration, + fs::File, net::SocketAddr, sync::{atomic::AtomicUsize, Arc}, time::Duration }; use libpt::log::{debug, info, trace, warn}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls_pemfile::{certs, private_key}; use tokio::{ - io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, - net::{TcpListener, TcpStream}, - time::{self, timeout}, + io::{split, AsyncReadExt, AsyncWriteExt, BufReader}, net::{TcpListener, TcpStream}, time::{self, timeout} }; +use tokio_rustls::{rustls, TlsAcceptor}; use crate::common::conf::Config; pub mod errors; use errors::*; +const BUF_SIZE: usize = 64; + pub struct Server { cfg: Config, pub timeout: Option, server: TcpListener, num_peers: AtomicUsize, + acceptor: TlsAcceptor, } impl Server { pub async fn build(cfg: Config) -> anyhow::Result { + let certs = Self::load_certs(cfg.clone())?; + let key = Self::load_key(cfg.clone())?.unwrap(); + let tls_config = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key)?; + let acceptor = TlsAcceptor::from(Arc::new(tls_config)); let server = TcpListener::bind(cfg.addr).await?; let timeout = Some(Duration::from_secs(5)); - let num_peers = AtomicUsize::new(0); + Ok(Server { cfg, timeout, server, - num_peers, + num_peers: AtomicUsize::new(0), + acceptor, }) } pub async fn run(self) -> anyhow::Result<()> { @@ -52,20 +61,22 @@ impl Server { } }); loop { + let (stream, addr) = match rc_self.server.accept().await { + Ok(s) => s, + Err(err) => { + warn!("could not accept stream: {err:?}"); + continue; + } + }; let ref_self = rc_self.clone(); - let (stream, addr) = match ref_self.server.accept().await { - Ok(s) => s, - Err(err) => { - warn!("could not accept stream: {err:?}"); - continue; - } - }; + let acceptor = rc_self.acceptor.clone(); // NOTE: we can only start the task now. If we start it before accepting connections // (so that the task theoretically accepts the connection), we would create endless - // tasks in a loop + // tasks in a loop. tokio::spawn(async move { + let stream: tokio_rustls::server::TlsStream<_> = acceptor.accept(stream).await.unwrap(); ref_self.peer_add(1); - match ref_self.handle_stream(stream).await { + match ref_self.handle_stream(stream, addr).await { Ok(_) => (), Err(err) => match err { ServerError::Timeout(_) => { @@ -81,10 +92,18 @@ impl Server { } } + fn load_key(cfg: Config) -> std::io::Result>> { + private_key(&mut std::io::BufReader::new(File::open(cfg.key)?)) + } + + fn load_certs(cfg: Config) -> std::io::Result>> { + certs(&mut std::io::BufReader::new(File::open(cfg.key)?)).collect() + } + #[inline] fn peer_add(&self, v: usize) { self.num_peers.store( - (self.num_peers.load(std::sync::atomic::Ordering::Relaxed) + v), + self.num_peers.load(std::sync::atomic::Ordering::Relaxed) + v, std::sync::atomic::Ordering::Relaxed, ) } @@ -92,56 +111,26 @@ impl Server { #[inline] fn peer_sub(&self, v: usize) { self.num_peers.store( - (self.num_peers.load(std::sync::atomic::Ordering::Relaxed) - v), + self.num_peers.load(std::sync::atomic::Ordering::Relaxed) - v, std::sync::atomic::Ordering::Relaxed, ) } - async fn handle_stream(&self, stream: TcpStream) -> Result<()> { + async fn handle_stream(&self, stream: tokio_rustls::server::TlsStream, addr: SocketAddr) -> Result<()> { let mut pings: usize = 0; - let addr = match stream.peer_addr() { - Ok(a) => a, - Err(err) => { - debug!("could not get peer address: {:?}", err); - return Err(err.into()); - } - }; debug!("new peer: {:?}", addr); - let mut buf = Vec::new(); - let mut reader = BufReader::new(stream); + let mut buf = [0; BUF_SIZE]; + let (mut reader, mut writer) = split(stream); loop { - match self.read(&mut reader, &mut buf).await { - Ok(len) if len == 0 => { - trace!("len is 0, so the stream has ended: {len:?}"); - break; - } - Ok(len) => len, + match reader.read(&mut buf).await { + Ok(len) if len == 0 => { break;}, + Ok(_) => (), Err(err) => { - match err { - ServerError::Timeout(_) => { - debug!("peer {:?} timed out", addr) - } - _ => return Err(err), - } - break; + eprintln!("reader.read err: {err}") } - }; - trace!("received message: {:X?}", buf); - let msg = self.decode(&buf)?; - debug!("< {:?} : {}", addr, msg); - if msg.contains("ping") { - pings += 1; } - if pings < self.cfg.win_after { - reader.write_all(b"pong\0").await?; - debug!("> {:?} : pong", addr,); - } else { - reader.write_all(b"you win!\0").await?; - debug!("> {:?} : you win!", addr,); - reader.shutdown().await?; - break; - } - buf.clear(); + + writer.write(b"pong\0").await?; // we should wait, so that we don't spam the client std::thread::sleep(self.cfg.delay); @@ -154,12 +143,4 @@ impl Server { fn decode(&self, buf: &Vec) -> Result { Ok(String::from_utf8(buf.clone())?.replace('\n', "\\n")) } - - #[inline] - async fn read(&self, reader: &mut BufReader, buf: &mut Vec) -> Result { - match timeout(self.cfg.timeout, reader.read_until(0x00, buf)).await? { - Ok(len) => Ok(len), - Err(err) => Err(err.into()), - } - } }