generated from PlexSheep/rs-base
very primitive tls server
cargo devel CI / cargo CI (push) Successful in 2m21s
Details
cargo devel CI / cargo CI (push) Successful in 2m21s
Details
This commit is contained in:
parent
dc407c8ce8
commit
e622a5cc99
|
@ -18,9 +18,11 @@ anyhow = "1.0.79"
|
|||
clap = "4.4.18"
|
||||
clap-num = "1.0.2"
|
||||
clap-verbosity-flag = "2.1.2"
|
||||
libpt = { version = "0.3.10", features = ["net"] }
|
||||
libpt = { version = "0.3.11", features = ["net"] }
|
||||
thiserror = "1.0.56"
|
||||
tokio = { version = "1.35.1", features = ["net", "rt", "macros"] }
|
||||
rustls-pemfile = "2.0.0"
|
||||
tokio-rustls = "0.25.0"
|
||||
|
||||
[features]
|
||||
default = ["server"]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use threadpool::ThreadPool;
|
||||
const MAX: usize = 20;
|
||||
use std::process::{exit, Command};
|
||||
use std::process::Command;
|
||||
|
||||
fn main() {
|
||||
let pool = ThreadPool::new(MAX);
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
use libpt::log::{Level, Logger};
|
||||
|
||||
use clap::Parser;
|
||||
use clap_verbosity_flag::{InfoLevel, Verbosity};
|
||||
|
||||
use crate::common::conf::Mode;
|
||||
|
||||
/// short about section displayed in help
|
||||
const ABOUT_ROOT: &'static str = r##"
|
||||
Let your hosts play ping pong over the network
|
||||
|
@ -46,16 +46,15 @@ pub(crate) struct Cli {
|
|||
#[arg(short, long, default_value_t = false)]
|
||||
pub(crate) server: bool,
|
||||
|
||||
// how much threads the server should use
|
||||
#[cfg(feature = "server")]
|
||||
#[arg(short, long, default_value_t = 4)]
|
||||
pub(crate) threads: usize,
|
||||
|
||||
#[arg(short, long, default_value_t = Mode::Tcp, ignore_case = true)]
|
||||
pub(crate) mode: Mode,
|
||||
|
||||
/// Address of the server
|
||||
pub(crate) addr: std::net::SocketAddr,
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
#[arg(short, long)]
|
||||
pub key: PathBuf,
|
||||
#[cfg(feature = "server")]
|
||||
#[arg(short, long)]
|
||||
pub certs: PathBuf,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
|
@ -72,7 +71,7 @@ impl Cli {
|
|||
}
|
||||
};
|
||||
if cli.meta {
|
||||
Logger::init(None, Some(ll)).expect("could not initialize Logger");
|
||||
Logger::init(None, Some(ll), true).expect("could not initialize Logger");
|
||||
} else {
|
||||
// less verbose version
|
||||
Logger::init_mini(Some(ll)).expect("could not initialize Logger");
|
||||
|
|
|
@ -1,68 +1,30 @@
|
|||
use crate::common::args::Cli;
|
||||
use clap::ValueEnum;
|
||||
use std::{fmt::Display, time::Duration};
|
||||
use std::{fmt::Display, path::PathBuf, time::Duration};
|
||||
|
||||
const DEFAULT_TIMEOUT_LEN: u64 = 5000; // ms
|
||||
const DEFAULT_DELAY_LEN: u64 = 500; // ms
|
||||
const DEFAULT_WIN_AFTER: usize = 20;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Mode {
|
||||
Tcp,
|
||||
Tls,
|
||||
}
|
||||
|
||||
impl ValueEnum for Mode {
|
||||
fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
|
||||
Some(match self {
|
||||
Self::Tcp => clap::builder::PossibleValue::new("tcp"),
|
||||
Self::Tls => clap::builder::PossibleValue::new("tls"),
|
||||
})
|
||||
}
|
||||
fn value_variants<'a>() -> &'a [Self] {
|
||||
&[Self::Tcp]
|
||||
}
|
||||
fn from_str(input: &str, ignore_case: bool) -> Result<Self, String> {
|
||||
let comp: String = if ignore_case {
|
||||
input.to_lowercase()
|
||||
} else {
|
||||
input.to_string()
|
||||
};
|
||||
match comp.as_str() {
|
||||
"tcp" => return Ok(Self::Tcp),
|
||||
_ => return Err(format!("\"{input}\" is not a valid mode")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Mode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let repr: String = match self {
|
||||
Self::Tcp => format!("tcp"),
|
||||
Self::Tls => format!("tls"),
|
||||
};
|
||||
write!(f, "{}", repr)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
pub addr: std::net::SocketAddr,
|
||||
pub mode: Mode,
|
||||
pub threads: usize,
|
||||
pub timeout: Duration,
|
||||
pub delay: Duration,
|
||||
pub win_after: usize,
|
||||
pub key: PathBuf,
|
||||
pub certs: PathBuf,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn new(cli: &Cli) -> Self {
|
||||
Config {
|
||||
addr: cli.addr.clone(),
|
||||
mode: cli.mode.clone(),
|
||||
threads: cli.threads,
|
||||
timeout: Duration::from_millis(DEFAULT_TIMEOUT_LEN),
|
||||
delay: Duration::from_millis(DEFAULT_DELAY_LEN),
|
||||
win_after: DEFAULT_WIN_AFTER,
|
||||
key: cli.key.clone(),
|
||||
certs: cli.certs.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,39 +1,48 @@
|
|||
#![cfg(feature = "server")]
|
||||
use std::{
|
||||
ops::Add,
|
||||
sync::{atomic::AtomicUsize, Arc},
|
||||
time::Duration,
|
||||
fs::File, net::SocketAddr, sync::{atomic::AtomicUsize, Arc}, time::Duration
|
||||
};
|
||||
|
||||
use libpt::log::{debug, info, trace, warn};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use rustls_pemfile::{certs, private_key};
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
||||
net::{TcpListener, TcpStream},
|
||||
time::{self, timeout},
|
||||
io::{split, AsyncReadExt, AsyncWriteExt, BufReader}, net::{TcpListener, TcpStream}, time::{self, timeout}
|
||||
};
|
||||
use tokio_rustls::{rustls, TlsAcceptor};
|
||||
|
||||
use crate::common::conf::Config;
|
||||
|
||||
pub mod errors;
|
||||
use errors::*;
|
||||
|
||||
const BUF_SIZE: usize = 64;
|
||||
|
||||
pub struct Server {
|
||||
cfg: Config,
|
||||
pub timeout: Option<Duration>,
|
||||
server: TcpListener,
|
||||
num_peers: AtomicUsize,
|
||||
acceptor: TlsAcceptor,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
pub async fn build(cfg: Config) -> anyhow::Result<Self> {
|
||||
let certs = Self::load_certs(cfg.clone())?;
|
||||
let key = Self::load_key(cfg.clone())?.unwrap();
|
||||
let tls_config = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?;
|
||||
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
|
||||
let server = TcpListener::bind(cfg.addr).await?;
|
||||
let timeout = Some(Duration::from_secs(5));
|
||||
let num_peers = AtomicUsize::new(0);
|
||||
|
||||
Ok(Server {
|
||||
cfg,
|
||||
timeout,
|
||||
server,
|
||||
num_peers,
|
||||
num_peers: AtomicUsize::new(0),
|
||||
acceptor,
|
||||
})
|
||||
}
|
||||
pub async fn run(self) -> anyhow::Result<()> {
|
||||
|
@ -52,20 +61,22 @@ impl Server {
|
|||
}
|
||||
});
|
||||
loop {
|
||||
let ref_self = rc_self.clone();
|
||||
let (stream, addr) = match ref_self.server.accept().await {
|
||||
let (stream, addr) = match rc_self.server.accept().await {
|
||||
Ok(s) => s,
|
||||
Err(err) => {
|
||||
warn!("could not accept stream: {err:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let ref_self = rc_self.clone();
|
||||
let acceptor = rc_self.acceptor.clone();
|
||||
// NOTE: we can only start the task now. If we start it before accepting connections
|
||||
// (so that the task theoretically accepts the connection), we would create endless
|
||||
// tasks in a loop
|
||||
// tasks in a loop.
|
||||
tokio::spawn(async move {
|
||||
let stream: tokio_rustls::server::TlsStream<_> = acceptor.accept(stream).await.unwrap();
|
||||
ref_self.peer_add(1);
|
||||
match ref_self.handle_stream(stream).await {
|
||||
match ref_self.handle_stream(stream, addr).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => match err {
|
||||
ServerError::Timeout(_) => {
|
||||
|
@ -81,10 +92,18 @@ impl Server {
|
|||
}
|
||||
}
|
||||
|
||||
fn load_key(cfg: Config) -> std::io::Result<Option<PrivateKeyDer<'static>>> {
|
||||
private_key(&mut std::io::BufReader::new(File::open(cfg.key)?))
|
||||
}
|
||||
|
||||
fn load_certs(cfg: Config) -> std::io::Result<Vec<CertificateDer<'static>>> {
|
||||
certs(&mut std::io::BufReader::new(File::open(cfg.key)?)).collect()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn peer_add(&self, v: usize) {
|
||||
self.num_peers.store(
|
||||
(self.num_peers.load(std::sync::atomic::Ordering::Relaxed) + v),
|
||||
self.num_peers.load(std::sync::atomic::Ordering::Relaxed) + v,
|
||||
std::sync::atomic::Ordering::Relaxed,
|
||||
)
|
||||
}
|
||||
|
@ -92,56 +111,26 @@ impl Server {
|
|||
#[inline]
|
||||
fn peer_sub(&self, v: usize) {
|
||||
self.num_peers.store(
|
||||
(self.num_peers.load(std::sync::atomic::Ordering::Relaxed) - v),
|
||||
self.num_peers.load(std::sync::atomic::Ordering::Relaxed) - v,
|
||||
std::sync::atomic::Ordering::Relaxed,
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_stream(&self, stream: TcpStream) -> Result<()> {
|
||||
async fn handle_stream(&self, stream: tokio_rustls::server::TlsStream<TcpStream>, addr: SocketAddr) -> Result<()> {
|
||||
let mut pings: usize = 0;
|
||||
let addr = match stream.peer_addr() {
|
||||
Ok(a) => a,
|
||||
Err(err) => {
|
||||
debug!("could not get peer address: {:?}", err);
|
||||
return Err(err.into());
|
||||
}
|
||||
};
|
||||
debug!("new peer: {:?}", addr);
|
||||
let mut buf = Vec::new();
|
||||
let mut reader = BufReader::new(stream);
|
||||
let mut buf = [0; BUF_SIZE];
|
||||
let (mut reader, mut writer) = split(stream);
|
||||
loop {
|
||||
match self.read(&mut reader, &mut buf).await {
|
||||
Ok(len) if len == 0 => {
|
||||
trace!("len is 0, so the stream has ended: {len:?}");
|
||||
break;
|
||||
}
|
||||
Ok(len) => len,
|
||||
match reader.read(&mut buf).await {
|
||||
Ok(len) if len == 0 => { break;},
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
match err {
|
||||
ServerError::Timeout(_) => {
|
||||
debug!("peer {:?} timed out", addr)
|
||||
eprintln!("reader.read err: {err}")
|
||||
}
|
||||
_ => return Err(err),
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
trace!("received message: {:X?}", buf);
|
||||
let msg = self.decode(&buf)?;
|
||||
debug!("< {:?} : {}", addr, msg);
|
||||
if msg.contains("ping") {
|
||||
pings += 1;
|
||||
}
|
||||
if pings < self.cfg.win_after {
|
||||
reader.write_all(b"pong\0").await?;
|
||||
debug!("> {:?} : pong", addr,);
|
||||
} else {
|
||||
reader.write_all(b"you win!\0").await?;
|
||||
debug!("> {:?} : you win!", addr,);
|
||||
reader.shutdown().await?;
|
||||
break;
|
||||
}
|
||||
buf.clear();
|
||||
|
||||
writer.write(b"pong\0").await?;
|
||||
|
||||
// we should wait, so that we don't spam the client
|
||||
std::thread::sleep(self.cfg.delay);
|
||||
|
@ -154,12 +143,4 @@ impl Server {
|
|||
fn decode(&self, buf: &Vec<u8>) -> Result<String> {
|
||||
Ok(String::from_utf8(buf.clone())?.replace('\n', "\\n"))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
async fn read(&self, reader: &mut BufReader<TcpStream>, buf: &mut Vec<u8>) -> Result<usize> {
|
||||
match timeout(self.cfg.timeout, reader.read_until(0x00, buf)).await? {
|
||||
Ok(len) => Ok(len),
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue