generated from PlexSheep/rs-base
very primitive tls server
cargo devel CI / cargo CI (push) Successful in 2m21s
Details
cargo devel CI / cargo CI (push) Successful in 2m21s
Details
This commit is contained in:
parent
dc407c8ce8
commit
e622a5cc99
|
@ -18,9 +18,11 @@ anyhow = "1.0.79"
|
||||||
clap = "4.4.18"
|
clap = "4.4.18"
|
||||||
clap-num = "1.0.2"
|
clap-num = "1.0.2"
|
||||||
clap-verbosity-flag = "2.1.2"
|
clap-verbosity-flag = "2.1.2"
|
||||||
libpt = { version = "0.3.10", features = ["net"] }
|
libpt = { version = "0.3.11", features = ["net"] }
|
||||||
thiserror = "1.0.56"
|
thiserror = "1.0.56"
|
||||||
tokio = { version = "1.35.1", features = ["net", "rt", "macros"] }
|
tokio = { version = "1.35.1", features = ["net", "rt", "macros"] }
|
||||||
|
rustls-pemfile = "2.0.0"
|
||||||
|
tokio-rustls = "0.25.0"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["server"]
|
default = ["server"]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use threadpool::ThreadPool;
|
use threadpool::ThreadPool;
|
||||||
const MAX: usize = 20;
|
const MAX: usize = 20;
|
||||||
use std::process::{exit, Command};
|
use std::process::Command;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let pool = ThreadPool::new(MAX);
|
let pool = ThreadPool::new(MAX);
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use libpt::log::{Level, Logger};
|
use libpt::log::{Level, Logger};
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use clap_verbosity_flag::{InfoLevel, Verbosity};
|
use clap_verbosity_flag::{InfoLevel, Verbosity};
|
||||||
|
|
||||||
use crate::common::conf::Mode;
|
|
||||||
|
|
||||||
/// short about section displayed in help
|
/// short about section displayed in help
|
||||||
const ABOUT_ROOT: &'static str = r##"
|
const ABOUT_ROOT: &'static str = r##"
|
||||||
Let your hosts play ping pong over the network
|
Let your hosts play ping pong over the network
|
||||||
|
@ -46,16 +46,15 @@ pub(crate) struct Cli {
|
||||||
#[arg(short, long, default_value_t = false)]
|
#[arg(short, long, default_value_t = false)]
|
||||||
pub(crate) server: bool,
|
pub(crate) server: bool,
|
||||||
|
|
||||||
// how much threads the server should use
|
|
||||||
#[cfg(feature = "server")]
|
|
||||||
#[arg(short, long, default_value_t = 4)]
|
|
||||||
pub(crate) threads: usize,
|
|
||||||
|
|
||||||
#[arg(short, long, default_value_t = Mode::Tcp, ignore_case = true)]
|
|
||||||
pub(crate) mode: Mode,
|
|
||||||
|
|
||||||
/// Address of the server
|
/// Address of the server
|
||||||
pub(crate) addr: std::net::SocketAddr,
|
pub(crate) addr: std::net::SocketAddr,
|
||||||
|
|
||||||
|
#[cfg(feature = "server")]
|
||||||
|
#[arg(short, long)]
|
||||||
|
pub key: PathBuf,
|
||||||
|
#[cfg(feature = "server")]
|
||||||
|
#[arg(short, long)]
|
||||||
|
pub certs: PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Cli {
|
impl Cli {
|
||||||
|
@ -72,7 +71,7 @@ impl Cli {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
if cli.meta {
|
if cli.meta {
|
||||||
Logger::init(None, Some(ll)).expect("could not initialize Logger");
|
Logger::init(None, Some(ll), true).expect("could not initialize Logger");
|
||||||
} else {
|
} else {
|
||||||
// less verbose version
|
// less verbose version
|
||||||
Logger::init_mini(Some(ll)).expect("could not initialize Logger");
|
Logger::init_mini(Some(ll)).expect("could not initialize Logger");
|
||||||
|
|
|
@ -1,68 +1,30 @@
|
||||||
use crate::common::args::Cli;
|
use crate::common::args::Cli;
|
||||||
use clap::ValueEnum;
|
use clap::ValueEnum;
|
||||||
use std::{fmt::Display, time::Duration};
|
use std::{fmt::Display, path::PathBuf, time::Duration};
|
||||||
|
|
||||||
const DEFAULT_TIMEOUT_LEN: u64 = 5000; // ms
|
const DEFAULT_TIMEOUT_LEN: u64 = 5000; // ms
|
||||||
const DEFAULT_DELAY_LEN: u64 = 500; // ms
|
const DEFAULT_DELAY_LEN: u64 = 500; // ms
|
||||||
const DEFAULT_WIN_AFTER: usize = 20;
|
const DEFAULT_WIN_AFTER: usize = 20;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Clone)]
|
||||||
pub enum Mode {
|
|
||||||
Tcp,
|
|
||||||
Tls,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ValueEnum for Mode {
|
|
||||||
fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
|
|
||||||
Some(match self {
|
|
||||||
Self::Tcp => clap::builder::PossibleValue::new("tcp"),
|
|
||||||
Self::Tls => clap::builder::PossibleValue::new("tls"),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
fn value_variants<'a>() -> &'a [Self] {
|
|
||||||
&[Self::Tcp]
|
|
||||||
}
|
|
||||||
fn from_str(input: &str, ignore_case: bool) -> Result<Self, String> {
|
|
||||||
let comp: String = if ignore_case {
|
|
||||||
input.to_lowercase()
|
|
||||||
} else {
|
|
||||||
input.to_string()
|
|
||||||
};
|
|
||||||
match comp.as_str() {
|
|
||||||
"tcp" => return Ok(Self::Tcp),
|
|
||||||
_ => return Err(format!("\"{input}\" is not a valid mode")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for Mode {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
let repr: String = match self {
|
|
||||||
Self::Tcp => format!("tcp"),
|
|
||||||
Self::Tls => format!("tls"),
|
|
||||||
};
|
|
||||||
write!(f, "{}", repr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub addr: std::net::SocketAddr,
|
pub addr: std::net::SocketAddr,
|
||||||
pub mode: Mode,
|
|
||||||
pub threads: usize,
|
|
||||||
pub timeout: Duration,
|
pub timeout: Duration,
|
||||||
pub delay: Duration,
|
pub delay: Duration,
|
||||||
pub win_after: usize,
|
pub win_after: usize,
|
||||||
|
pub key: PathBuf,
|
||||||
|
pub certs: PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
pub fn new(cli: &Cli) -> Self {
|
pub fn new(cli: &Cli) -> Self {
|
||||||
Config {
|
Config {
|
||||||
addr: cli.addr.clone(),
|
addr: cli.addr.clone(),
|
||||||
mode: cli.mode.clone(),
|
|
||||||
threads: cli.threads,
|
|
||||||
timeout: Duration::from_millis(DEFAULT_TIMEOUT_LEN),
|
timeout: Duration::from_millis(DEFAULT_TIMEOUT_LEN),
|
||||||
delay: Duration::from_millis(DEFAULT_DELAY_LEN),
|
delay: Duration::from_millis(DEFAULT_DELAY_LEN),
|
||||||
win_after: DEFAULT_WIN_AFTER,
|
win_after: DEFAULT_WIN_AFTER,
|
||||||
|
key: cli.key.clone(),
|
||||||
|
certs: cli.certs.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,39 +1,48 @@
|
||||||
#![cfg(feature = "server")]
|
#![cfg(feature = "server")]
|
||||||
use std::{
|
use std::{
|
||||||
ops::Add,
|
fs::File, net::SocketAddr, sync::{atomic::AtomicUsize, Arc}, time::Duration
|
||||||
sync::{atomic::AtomicUsize, Arc},
|
|
||||||
time::Duration,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use libpt::log::{debug, info, trace, warn};
|
use libpt::log::{debug, info, trace, warn};
|
||||||
|
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||||
|
use rustls_pemfile::{certs, private_key};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
io::{split, AsyncReadExt, AsyncWriteExt, BufReader}, net::{TcpListener, TcpStream}, time::{self, timeout}
|
||||||
net::{TcpListener, TcpStream},
|
|
||||||
time::{self, timeout},
|
|
||||||
};
|
};
|
||||||
|
use tokio_rustls::{rustls, TlsAcceptor};
|
||||||
|
|
||||||
use crate::common::conf::Config;
|
use crate::common::conf::Config;
|
||||||
|
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
use errors::*;
|
use errors::*;
|
||||||
|
|
||||||
|
const BUF_SIZE: usize = 64;
|
||||||
|
|
||||||
pub struct Server {
|
pub struct Server {
|
||||||
cfg: Config,
|
cfg: Config,
|
||||||
pub timeout: Option<Duration>,
|
pub timeout: Option<Duration>,
|
||||||
server: TcpListener,
|
server: TcpListener,
|
||||||
num_peers: AtomicUsize,
|
num_peers: AtomicUsize,
|
||||||
|
acceptor: TlsAcceptor,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Server {
|
impl Server {
|
||||||
pub async fn build(cfg: Config) -> anyhow::Result<Self> {
|
pub async fn build(cfg: Config) -> anyhow::Result<Self> {
|
||||||
|
let certs = Self::load_certs(cfg.clone())?;
|
||||||
|
let key = Self::load_key(cfg.clone())?.unwrap();
|
||||||
|
let tls_config = rustls::ServerConfig::builder()
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_single_cert(certs, key)?;
|
||||||
|
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
|
||||||
let server = TcpListener::bind(cfg.addr).await?;
|
let server = TcpListener::bind(cfg.addr).await?;
|
||||||
let timeout = Some(Duration::from_secs(5));
|
let timeout = Some(Duration::from_secs(5));
|
||||||
let num_peers = AtomicUsize::new(0);
|
|
||||||
Ok(Server {
|
Ok(Server {
|
||||||
cfg,
|
cfg,
|
||||||
timeout,
|
timeout,
|
||||||
server,
|
server,
|
||||||
num_peers,
|
num_peers: AtomicUsize::new(0),
|
||||||
|
acceptor,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
pub async fn run(self) -> anyhow::Result<()> {
|
pub async fn run(self) -> anyhow::Result<()> {
|
||||||
|
@ -52,20 +61,22 @@ impl Server {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
loop {
|
loop {
|
||||||
let ref_self = rc_self.clone();
|
let (stream, addr) = match rc_self.server.accept().await {
|
||||||
let (stream, addr) = match ref_self.server.accept().await {
|
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!("could not accept stream: {err:?}");
|
warn!("could not accept stream: {err:?}");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let ref_self = rc_self.clone();
|
||||||
|
let acceptor = rc_self.acceptor.clone();
|
||||||
// NOTE: we can only start the task now. If we start it before accepting connections
|
// NOTE: we can only start the task now. If we start it before accepting connections
|
||||||
// (so that the task theoretically accepts the connection), we would create endless
|
// (so that the task theoretically accepts the connection), we would create endless
|
||||||
// tasks in a loop
|
// tasks in a loop.
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
let stream: tokio_rustls::server::TlsStream<_> = acceptor.accept(stream).await.unwrap();
|
||||||
ref_self.peer_add(1);
|
ref_self.peer_add(1);
|
||||||
match ref_self.handle_stream(stream).await {
|
match ref_self.handle_stream(stream, addr).await {
|
||||||
Ok(_) => (),
|
Ok(_) => (),
|
||||||
Err(err) => match err {
|
Err(err) => match err {
|
||||||
ServerError::Timeout(_) => {
|
ServerError::Timeout(_) => {
|
||||||
|
@ -81,10 +92,18 @@ impl Server {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn load_key(cfg: Config) -> std::io::Result<Option<PrivateKeyDer<'static>>> {
|
||||||
|
private_key(&mut std::io::BufReader::new(File::open(cfg.key)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_certs(cfg: Config) -> std::io::Result<Vec<CertificateDer<'static>>> {
|
||||||
|
certs(&mut std::io::BufReader::new(File::open(cfg.key)?)).collect()
|
||||||
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn peer_add(&self, v: usize) {
|
fn peer_add(&self, v: usize) {
|
||||||
self.num_peers.store(
|
self.num_peers.store(
|
||||||
(self.num_peers.load(std::sync::atomic::Ordering::Relaxed) + v),
|
self.num_peers.load(std::sync::atomic::Ordering::Relaxed) + v,
|
||||||
std::sync::atomic::Ordering::Relaxed,
|
std::sync::atomic::Ordering::Relaxed,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -92,56 +111,26 @@ impl Server {
|
||||||
#[inline]
|
#[inline]
|
||||||
fn peer_sub(&self, v: usize) {
|
fn peer_sub(&self, v: usize) {
|
||||||
self.num_peers.store(
|
self.num_peers.store(
|
||||||
(self.num_peers.load(std::sync::atomic::Ordering::Relaxed) - v),
|
self.num_peers.load(std::sync::atomic::Ordering::Relaxed) - v,
|
||||||
std::sync::atomic::Ordering::Relaxed,
|
std::sync::atomic::Ordering::Relaxed,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_stream(&self, stream: TcpStream) -> Result<()> {
|
async fn handle_stream(&self, stream: tokio_rustls::server::TlsStream<TcpStream>, addr: SocketAddr) -> Result<()> {
|
||||||
let mut pings: usize = 0;
|
let mut pings: usize = 0;
|
||||||
let addr = match stream.peer_addr() {
|
|
||||||
Ok(a) => a,
|
|
||||||
Err(err) => {
|
|
||||||
debug!("could not get peer address: {:?}", err);
|
|
||||||
return Err(err.into());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
debug!("new peer: {:?}", addr);
|
debug!("new peer: {:?}", addr);
|
||||||
let mut buf = Vec::new();
|
let mut buf = [0; BUF_SIZE];
|
||||||
let mut reader = BufReader::new(stream);
|
let (mut reader, mut writer) = split(stream);
|
||||||
loop {
|
loop {
|
||||||
match self.read(&mut reader, &mut buf).await {
|
match reader.read(&mut buf).await {
|
||||||
Ok(len) if len == 0 => {
|
Ok(len) if len == 0 => { break;},
|
||||||
trace!("len is 0, so the stream has ended: {len:?}");
|
Ok(_) => (),
|
||||||
break;
|
|
||||||
}
|
|
||||||
Ok(len) => len,
|
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
match err {
|
eprintln!("reader.read err: {err}")
|
||||||
ServerError::Timeout(_) => {
|
|
||||||
debug!("peer {:?} timed out", addr)
|
|
||||||
}
|
}
|
||||||
_ => return Err(err),
|
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
}
|
writer.write(b"pong\0").await?;
|
||||||
};
|
|
||||||
trace!("received message: {:X?}", buf);
|
|
||||||
let msg = self.decode(&buf)?;
|
|
||||||
debug!("< {:?} : {}", addr, msg);
|
|
||||||
if msg.contains("ping") {
|
|
||||||
pings += 1;
|
|
||||||
}
|
|
||||||
if pings < self.cfg.win_after {
|
|
||||||
reader.write_all(b"pong\0").await?;
|
|
||||||
debug!("> {:?} : pong", addr,);
|
|
||||||
} else {
|
|
||||||
reader.write_all(b"you win!\0").await?;
|
|
||||||
debug!("> {:?} : you win!", addr,);
|
|
||||||
reader.shutdown().await?;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
buf.clear();
|
|
||||||
|
|
||||||
// we should wait, so that we don't spam the client
|
// we should wait, so that we don't spam the client
|
||||||
std::thread::sleep(self.cfg.delay);
|
std::thread::sleep(self.cfg.delay);
|
||||||
|
@ -154,12 +143,4 @@ impl Server {
|
||||||
fn decode(&self, buf: &Vec<u8>) -> Result<String> {
|
fn decode(&self, buf: &Vec<u8>) -> Result<String> {
|
||||||
Ok(String::from_utf8(buf.clone())?.replace('\n', "\\n"))
|
Ok(String::from_utf8(buf.clone())?.replace('\n', "\\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
async fn read(&self, reader: &mut BufReader<TcpStream>, buf: &mut Vec<u8>) -> Result<usize> {
|
|
||||||
match timeout(self.cfg.timeout, reader.read_until(0x00, buf)).await? {
|
|
||||||
Ok(len) => Ok(len),
|
|
||||||
Err(err) => Err(err.into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue