very primitive tls server
cargo devel CI / cargo CI (push) Successful in 2m21s Details

This commit is contained in:
Christoph J. Scherr 2024-01-23 18:01:41 +01:00
parent dc407c8ce8
commit e622a5cc99
Signed by: cscherrNT
GPG Key ID: 8E2B45BC51A27EA7
5 changed files with 67 additions and 123 deletions

View File

@ -18,9 +18,11 @@ anyhow = "1.0.79"
clap = "4.4.18" clap = "4.4.18"
clap-num = "1.0.2" clap-num = "1.0.2"
clap-verbosity-flag = "2.1.2" clap-verbosity-flag = "2.1.2"
libpt = { version = "0.3.10", features = ["net"] } libpt = { version = "0.3.11", features = ["net"] }
thiserror = "1.0.56" thiserror = "1.0.56"
tokio = { version = "1.35.1", features = ["net", "rt", "macros"] } tokio = { version = "1.35.1", features = ["net", "rt", "macros"] }
rustls-pemfile = "2.0.0"
tokio-rustls = "0.25.0"
[features] [features]
default = ["server"] default = ["server"]

View File

@ -1,6 +1,6 @@
use threadpool::ThreadPool; use threadpool::ThreadPool;
const MAX: usize = 20; const MAX: usize = 20;
use std::process::{exit, Command}; use std::process::Command;
fn main() { fn main() {
let pool = ThreadPool::new(MAX); let pool = ThreadPool::new(MAX);

View File

@ -1,10 +1,10 @@
use std::path::PathBuf;
use libpt::log::{Level, Logger}; use libpt::log::{Level, Logger};
use clap::Parser; use clap::Parser;
use clap_verbosity_flag::{InfoLevel, Verbosity}; use clap_verbosity_flag::{InfoLevel, Verbosity};
use crate::common::conf::Mode;
/// short about section displayed in help /// short about section displayed in help
const ABOUT_ROOT: &'static str = r##" const ABOUT_ROOT: &'static str = r##"
Let your hosts play ping pong over the network Let your hosts play ping pong over the network
@ -46,16 +46,15 @@ pub(crate) struct Cli {
#[arg(short, long, default_value_t = false)] #[arg(short, long, default_value_t = false)]
pub(crate) server: bool, pub(crate) server: bool,
// how much threads the server should use
#[cfg(feature = "server")]
#[arg(short, long, default_value_t = 4)]
pub(crate) threads: usize,
#[arg(short, long, default_value_t = Mode::Tcp, ignore_case = true)]
pub(crate) mode: Mode,
/// Address of the server /// Address of the server
pub(crate) addr: std::net::SocketAddr, pub(crate) addr: std::net::SocketAddr,
#[cfg(feature = "server")]
#[arg(short, long)]
pub key: PathBuf,
#[cfg(feature = "server")]
#[arg(short, long)]
pub certs: PathBuf,
} }
impl Cli { impl Cli {
@ -72,7 +71,7 @@ impl Cli {
} }
}; };
if cli.meta { if cli.meta {
Logger::init(None, Some(ll)).expect("could not initialize Logger"); Logger::init(None, Some(ll), true).expect("could not initialize Logger");
} else { } else {
// less verbose version // less verbose version
Logger::init_mini(Some(ll)).expect("could not initialize Logger"); Logger::init_mini(Some(ll)).expect("could not initialize Logger");

View File

@ -1,68 +1,30 @@
use crate::common::args::Cli; use crate::common::args::Cli;
use clap::ValueEnum; use clap::ValueEnum;
use std::{fmt::Display, time::Duration}; use std::{fmt::Display, path::PathBuf, time::Duration};
const DEFAULT_TIMEOUT_LEN: u64 = 5000; // ms const DEFAULT_TIMEOUT_LEN: u64 = 5000; // ms
const DEFAULT_DELAY_LEN: u64 = 500; // ms const DEFAULT_DELAY_LEN: u64 = 500; // ms
const DEFAULT_WIN_AFTER: usize = 20; const DEFAULT_WIN_AFTER: usize = 20;
#[derive(Debug, Clone, Copy)] #[derive(Clone)]
pub enum Mode {
Tcp,
Tls,
}
impl ValueEnum for Mode {
fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
Some(match self {
Self::Tcp => clap::builder::PossibleValue::new("tcp"),
Self::Tls => clap::builder::PossibleValue::new("tls"),
})
}
fn value_variants<'a>() -> &'a [Self] {
&[Self::Tcp]
}
fn from_str(input: &str, ignore_case: bool) -> Result<Self, String> {
let comp: String = if ignore_case {
input.to_lowercase()
} else {
input.to_string()
};
match comp.as_str() {
"tcp" => return Ok(Self::Tcp),
_ => return Err(format!("\"{input}\" is not a valid mode")),
}
}
}
impl Display for Mode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let repr: String = match self {
Self::Tcp => format!("tcp"),
Self::Tls => format!("tls"),
};
write!(f, "{}", repr)
}
}
pub struct Config { pub struct Config {
pub addr: std::net::SocketAddr, pub addr: std::net::SocketAddr,
pub mode: Mode,
pub threads: usize,
pub timeout: Duration, pub timeout: Duration,
pub delay: Duration, pub delay: Duration,
pub win_after: usize, pub win_after: usize,
pub key: PathBuf,
pub certs: PathBuf,
} }
impl Config { impl Config {
pub fn new(cli: &Cli) -> Self { pub fn new(cli: &Cli) -> Self {
Config { Config {
addr: cli.addr.clone(), addr: cli.addr.clone(),
mode: cli.mode.clone(),
threads: cli.threads,
timeout: Duration::from_millis(DEFAULT_TIMEOUT_LEN), timeout: Duration::from_millis(DEFAULT_TIMEOUT_LEN),
delay: Duration::from_millis(DEFAULT_DELAY_LEN), delay: Duration::from_millis(DEFAULT_DELAY_LEN),
win_after: DEFAULT_WIN_AFTER, win_after: DEFAULT_WIN_AFTER,
key: cli.key.clone(),
certs: cli.certs.clone(),
} }
} }
} }

View File

@ -1,39 +1,48 @@
#![cfg(feature = "server")] #![cfg(feature = "server")]
use std::{ use std::{
ops::Add, fs::File, net::SocketAddr, sync::{atomic::AtomicUsize, Arc}, time::Duration
sync::{atomic::AtomicUsize, Arc},
time::Duration,
}; };
use libpt::log::{debug, info, trace, warn}; use libpt::log::{debug, info, trace, warn};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls_pemfile::{certs, private_key};
use tokio::{ use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, io::{split, AsyncReadExt, AsyncWriteExt, BufReader}, net::{TcpListener, TcpStream}, time::{self, timeout}
net::{TcpListener, TcpStream},
time::{self, timeout},
}; };
use tokio_rustls::{rustls, TlsAcceptor};
use crate::common::conf::Config; use crate::common::conf::Config;
pub mod errors; pub mod errors;
use errors::*; use errors::*;
const BUF_SIZE: usize = 64;
pub struct Server { pub struct Server {
cfg: Config, cfg: Config,
pub timeout: Option<Duration>, pub timeout: Option<Duration>,
server: TcpListener, server: TcpListener,
num_peers: AtomicUsize, num_peers: AtomicUsize,
acceptor: TlsAcceptor,
} }
impl Server { impl Server {
pub async fn build(cfg: Config) -> anyhow::Result<Self> { pub async fn build(cfg: Config) -> anyhow::Result<Self> {
let certs = Self::load_certs(cfg.clone())?;
let key = Self::load_key(cfg.clone())?.unwrap();
let tls_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)?;
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
let server = TcpListener::bind(cfg.addr).await?; let server = TcpListener::bind(cfg.addr).await?;
let timeout = Some(Duration::from_secs(5)); let timeout = Some(Duration::from_secs(5));
let num_peers = AtomicUsize::new(0);
Ok(Server { Ok(Server {
cfg, cfg,
timeout, timeout,
server, server,
num_peers, num_peers: AtomicUsize::new(0),
acceptor,
}) })
} }
pub async fn run(self) -> anyhow::Result<()> { pub async fn run(self) -> anyhow::Result<()> {
@ -52,20 +61,22 @@ impl Server {
} }
}); });
loop { loop {
let ref_self = rc_self.clone(); let (stream, addr) = match rc_self.server.accept().await {
let (stream, addr) = match ref_self.server.accept().await {
Ok(s) => s, Ok(s) => s,
Err(err) => { Err(err) => {
warn!("could not accept stream: {err:?}"); warn!("could not accept stream: {err:?}");
continue; continue;
} }
}; };
let ref_self = rc_self.clone();
let acceptor = rc_self.acceptor.clone();
// NOTE: we can only start the task now. If we start it before accepting connections // NOTE: we can only start the task now. If we start it before accepting connections
// (so that the task theoretically accepts the connection), we would create endless // (so that the task theoretically accepts the connection), we would create endless
// tasks in a loop // tasks in a loop.
tokio::spawn(async move { tokio::spawn(async move {
let stream: tokio_rustls::server::TlsStream<_> = acceptor.accept(stream).await.unwrap();
ref_self.peer_add(1); ref_self.peer_add(1);
match ref_self.handle_stream(stream).await { match ref_self.handle_stream(stream, addr).await {
Ok(_) => (), Ok(_) => (),
Err(err) => match err { Err(err) => match err {
ServerError::Timeout(_) => { ServerError::Timeout(_) => {
@ -81,10 +92,18 @@ impl Server {
} }
} }
fn load_key(cfg: Config) -> std::io::Result<Option<PrivateKeyDer<'static>>> {
private_key(&mut std::io::BufReader::new(File::open(cfg.key)?))
}
fn load_certs(cfg: Config) -> std::io::Result<Vec<CertificateDer<'static>>> {
certs(&mut std::io::BufReader::new(File::open(cfg.key)?)).collect()
}
#[inline] #[inline]
fn peer_add(&self, v: usize) { fn peer_add(&self, v: usize) {
self.num_peers.store( self.num_peers.store(
(self.num_peers.load(std::sync::atomic::Ordering::Relaxed) + v), self.num_peers.load(std::sync::atomic::Ordering::Relaxed) + v,
std::sync::atomic::Ordering::Relaxed, std::sync::atomic::Ordering::Relaxed,
) )
} }
@ -92,56 +111,26 @@ impl Server {
#[inline] #[inline]
fn peer_sub(&self, v: usize) { fn peer_sub(&self, v: usize) {
self.num_peers.store( self.num_peers.store(
(self.num_peers.load(std::sync::atomic::Ordering::Relaxed) - v), self.num_peers.load(std::sync::atomic::Ordering::Relaxed) - v,
std::sync::atomic::Ordering::Relaxed, std::sync::atomic::Ordering::Relaxed,
) )
} }
async fn handle_stream(&self, stream: TcpStream) -> Result<()> { async fn handle_stream(&self, stream: tokio_rustls::server::TlsStream<TcpStream>, addr: SocketAddr) -> Result<()> {
let mut pings: usize = 0; let mut pings: usize = 0;
let addr = match stream.peer_addr() {
Ok(a) => a,
Err(err) => {
debug!("could not get peer address: {:?}", err);
return Err(err.into());
}
};
debug!("new peer: {:?}", addr); debug!("new peer: {:?}", addr);
let mut buf = Vec::new(); let mut buf = [0; BUF_SIZE];
let mut reader = BufReader::new(stream); let (mut reader, mut writer) = split(stream);
loop { loop {
match self.read(&mut reader, &mut buf).await { match reader.read(&mut buf).await {
Ok(len) if len == 0 => { Ok(len) if len == 0 => { break;},
trace!("len is 0, so the stream has ended: {len:?}"); Ok(_) => (),
break;
}
Ok(len) => len,
Err(err) => { Err(err) => {
match err { eprintln!("reader.read err: {err}")
ServerError::Timeout(_) => {
debug!("peer {:?} timed out", addr)
} }
_ => return Err(err),
} }
break;
} writer.write(b"pong\0").await?;
};
trace!("received message: {:X?}", buf);
let msg = self.decode(&buf)?;
debug!("< {:?} : {}", addr, msg);
if msg.contains("ping") {
pings += 1;
}
if pings < self.cfg.win_after {
reader.write_all(b"pong\0").await?;
debug!("> {:?} : pong", addr,);
} else {
reader.write_all(b"you win!\0").await?;
debug!("> {:?} : you win!", addr,);
reader.shutdown().await?;
break;
}
buf.clear();
// we should wait, so that we don't spam the client // we should wait, so that we don't spam the client
std::thread::sleep(self.cfg.delay); std::thread::sleep(self.cfg.delay);
@ -154,12 +143,4 @@ impl Server {
fn decode(&self, buf: &Vec<u8>) -> Result<String> { fn decode(&self, buf: &Vec<u8>) -> Result<String> {
Ok(String::from_utf8(buf.clone())?.replace('\n', "\\n")) Ok(String::from_utf8(buf.clone())?.replace('\n', "\\n"))
} }
#[inline]
async fn read(&self, reader: &mut BufReader<TcpStream>, buf: &mut Vec<u8>) -> Result<usize> {
match timeout(self.cfg.timeout, reader.read_until(0x00, buf)).await? {
Ok(len) => Ok(len),
Err(err) => Err(err.into()),
}
}
} }