very primitive tls server
cargo devel CI / cargo CI (push) Successful in 2m21s Details

This commit is contained in:
Christoph J. Scherr 2024-01-23 18:01:41 +01:00
parent dc407c8ce8
commit e622a5cc99
Signed by: cscherrNT
GPG Key ID: 8E2B45BC51A27EA7
5 changed files with 67 additions and 123 deletions

View File

@ -18,9 +18,11 @@ anyhow = "1.0.79"
clap = "4.4.18"
clap-num = "1.0.2"
clap-verbosity-flag = "2.1.2"
libpt = { version = "0.3.10", features = ["net"] }
libpt = { version = "0.3.11", features = ["net"] }
thiserror = "1.0.56"
tokio = { version = "1.35.1", features = ["net", "rt", "macros"] }
rustls-pemfile = "2.0.0"
tokio-rustls = "0.25.0"
[features]
default = ["server"]

View File

@ -1,6 +1,6 @@
use threadpool::ThreadPool;
const MAX: usize = 20;
use std::process::{exit, Command};
use std::process::Command;
fn main() {
let pool = ThreadPool::new(MAX);

View File

@ -1,10 +1,10 @@
use std::path::PathBuf;
use libpt::log::{Level, Logger};
use clap::Parser;
use clap_verbosity_flag::{InfoLevel, Verbosity};
use crate::common::conf::Mode;
/// short about section displayed in help
const ABOUT_ROOT: &'static str = r##"
Let your hosts play ping pong over the network
@ -46,16 +46,15 @@ pub(crate) struct Cli {
#[arg(short, long, default_value_t = false)]
pub(crate) server: bool,
// how much threads the server should use
#[cfg(feature = "server")]
#[arg(short, long, default_value_t = 4)]
pub(crate) threads: usize,
#[arg(short, long, default_value_t = Mode::Tcp, ignore_case = true)]
pub(crate) mode: Mode,
/// Address of the server
pub(crate) addr: std::net::SocketAddr,
#[cfg(feature = "server")]
#[arg(short, long)]
pub key: PathBuf,
#[cfg(feature = "server")]
#[arg(short, long)]
pub certs: PathBuf,
}
impl Cli {
@ -72,7 +71,7 @@ impl Cli {
}
};
if cli.meta {
Logger::init(None, Some(ll)).expect("could not initialize Logger");
Logger::init(None, Some(ll), true).expect("could not initialize Logger");
} else {
// less verbose version
Logger::init_mini(Some(ll)).expect("could not initialize Logger");

View File

@ -1,68 +1,30 @@
use crate::common::args::Cli;
use clap::ValueEnum;
use std::{fmt::Display, time::Duration};
use std::{fmt::Display, path::PathBuf, time::Duration};
const DEFAULT_TIMEOUT_LEN: u64 = 5000; // ms
const DEFAULT_DELAY_LEN: u64 = 500; // ms
const DEFAULT_WIN_AFTER: usize = 20;
#[derive(Debug, Clone, Copy)]
pub enum Mode {
Tcp,
Tls,
}
impl ValueEnum for Mode {
fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
Some(match self {
Self::Tcp => clap::builder::PossibleValue::new("tcp"),
Self::Tls => clap::builder::PossibleValue::new("tls"),
})
}
fn value_variants<'a>() -> &'a [Self] {
&[Self::Tcp]
}
fn from_str(input: &str, ignore_case: bool) -> Result<Self, String> {
let comp: String = if ignore_case {
input.to_lowercase()
} else {
input.to_string()
};
match comp.as_str() {
"tcp" => return Ok(Self::Tcp),
_ => return Err(format!("\"{input}\" is not a valid mode")),
}
}
}
impl Display for Mode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let repr: String = match self {
Self::Tcp => format!("tcp"),
Self::Tls => format!("tls"),
};
write!(f, "{}", repr)
}
}
#[derive(Clone)]
pub struct Config {
pub addr: std::net::SocketAddr,
pub mode: Mode,
pub threads: usize,
pub timeout: Duration,
pub delay: Duration,
pub win_after: usize,
pub key: PathBuf,
pub certs: PathBuf,
}
impl Config {
pub fn new(cli: &Cli) -> Self {
Config {
addr: cli.addr.clone(),
mode: cli.mode.clone(),
threads: cli.threads,
timeout: Duration::from_millis(DEFAULT_TIMEOUT_LEN),
delay: Duration::from_millis(DEFAULT_DELAY_LEN),
win_after: DEFAULT_WIN_AFTER,
key: cli.key.clone(),
certs: cli.certs.clone(),
}
}
}

View File

@ -1,39 +1,48 @@
#![cfg(feature = "server")]
use std::{
ops::Add,
sync::{atomic::AtomicUsize, Arc},
time::Duration,
fs::File, net::SocketAddr, sync::{atomic::AtomicUsize, Arc}, time::Duration
};
use libpt::log::{debug, info, trace, warn};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls_pemfile::{certs, private_key};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{TcpListener, TcpStream},
time::{self, timeout},
io::{split, AsyncReadExt, AsyncWriteExt, BufReader}, net::{TcpListener, TcpStream}, time::{self, timeout}
};
use tokio_rustls::{rustls, TlsAcceptor};
use crate::common::conf::Config;
pub mod errors;
use errors::*;
const BUF_SIZE: usize = 64;
pub struct Server {
cfg: Config,
pub timeout: Option<Duration>,
server: TcpListener,
num_peers: AtomicUsize,
acceptor: TlsAcceptor,
}
impl Server {
pub async fn build(cfg: Config) -> anyhow::Result<Self> {
let certs = Self::load_certs(cfg.clone())?;
let key = Self::load_key(cfg.clone())?.unwrap();
let tls_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)?;
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
let server = TcpListener::bind(cfg.addr).await?;
let timeout = Some(Duration::from_secs(5));
let num_peers = AtomicUsize::new(0);
Ok(Server {
cfg,
timeout,
server,
num_peers,
num_peers: AtomicUsize::new(0),
acceptor,
})
}
pub async fn run(self) -> anyhow::Result<()> {
@ -52,20 +61,22 @@ impl Server {
}
});
loop {
let (stream, addr) = match rc_self.server.accept().await {
Ok(s) => s,
Err(err) => {
warn!("could not accept stream: {err:?}");
continue;
}
};
let ref_self = rc_self.clone();
let (stream, addr) = match ref_self.server.accept().await {
Ok(s) => s,
Err(err) => {
warn!("could not accept stream: {err:?}");
continue;
}
};
let acceptor = rc_self.acceptor.clone();
// NOTE: we can only start the task now. If we start it before accepting connections
// (so that the task theoretically accepts the connection), we would create endless
// tasks in a loop
// tasks in a loop.
tokio::spawn(async move {
let stream: tokio_rustls::server::TlsStream<_> = acceptor.accept(stream).await.unwrap();
ref_self.peer_add(1);
match ref_self.handle_stream(stream).await {
match ref_self.handle_stream(stream, addr).await {
Ok(_) => (),
Err(err) => match err {
ServerError::Timeout(_) => {
@ -81,10 +92,18 @@ impl Server {
}
}
fn load_key(cfg: Config) -> std::io::Result<Option<PrivateKeyDer<'static>>> {
private_key(&mut std::io::BufReader::new(File::open(cfg.key)?))
}
fn load_certs(cfg: Config) -> std::io::Result<Vec<CertificateDer<'static>>> {
certs(&mut std::io::BufReader::new(File::open(cfg.key)?)).collect()
}
#[inline]
fn peer_add(&self, v: usize) {
self.num_peers.store(
(self.num_peers.load(std::sync::atomic::Ordering::Relaxed) + v),
self.num_peers.load(std::sync::atomic::Ordering::Relaxed) + v,
std::sync::atomic::Ordering::Relaxed,
)
}
@ -92,56 +111,26 @@ impl Server {
#[inline]
fn peer_sub(&self, v: usize) {
self.num_peers.store(
(self.num_peers.load(std::sync::atomic::Ordering::Relaxed) - v),
self.num_peers.load(std::sync::atomic::Ordering::Relaxed) - v,
std::sync::atomic::Ordering::Relaxed,
)
}
async fn handle_stream(&self, stream: TcpStream) -> Result<()> {
async fn handle_stream(&self, stream: tokio_rustls::server::TlsStream<TcpStream>, addr: SocketAddr) -> Result<()> {
let mut pings: usize = 0;
let addr = match stream.peer_addr() {
Ok(a) => a,
Err(err) => {
debug!("could not get peer address: {:?}", err);
return Err(err.into());
}
};
debug!("new peer: {:?}", addr);
let mut buf = Vec::new();
let mut reader = BufReader::new(stream);
let mut buf = [0; BUF_SIZE];
let (mut reader, mut writer) = split(stream);
loop {
match self.read(&mut reader, &mut buf).await {
Ok(len) if len == 0 => {
trace!("len is 0, so the stream has ended: {len:?}");
break;
}
Ok(len) => len,
match reader.read(&mut buf).await {
Ok(len) if len == 0 => { break;},
Ok(_) => (),
Err(err) => {
match err {
ServerError::Timeout(_) => {
debug!("peer {:?} timed out", addr)
}
_ => return Err(err),
}
break;
eprintln!("reader.read err: {err}")
}
};
trace!("received message: {:X?}", buf);
let msg = self.decode(&buf)?;
debug!("< {:?} : {}", addr, msg);
if msg.contains("ping") {
pings += 1;
}
if pings < self.cfg.win_after {
reader.write_all(b"pong\0").await?;
debug!("> {:?} : pong", addr,);
} else {
reader.write_all(b"you win!\0").await?;
debug!("> {:?} : you win!", addr,);
reader.shutdown().await?;
break;
}
buf.clear();
writer.write(b"pong\0").await?;
// we should wait, so that we don't spam the client
std::thread::sleep(self.cfg.delay);
@ -154,12 +143,4 @@ impl Server {
fn decode(&self, buf: &Vec<u8>) -> Result<String> {
Ok(String::from_utf8(buf.clone())?.replace('\n', "\\n"))
}
#[inline]
async fn read(&self, reader: &mut BufReader<TcpStream>, buf: &mut Vec<u8>) -> Result<usize> {
match timeout(self.cfg.timeout, reader.read_until(0x00, buf)).await? {
Ok(len) => Ok(len),
Err(err) => Err(err.into()),
}
}
}