diff --git a/Cargo.toml b/Cargo.toml index 05cc371..c9a1733 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ thiserror = "1.0.56" tokio = { version = "1.35.1", features = ["net", "rt", "macros"] } rustls-pemfile = "2.0.0" tokio-rustls = "0.25.0" +webpki-roots = "0.26.0" [features] default = ["server"] diff --git a/src/client/mod.rs b/src/client/mod.rs index 8b13789..baa8d5f 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1 +1,74 @@ +#![cfg(feature = "server")] +use std::{fs::File, io::BufReader, sync::Arc}; +use crate::{common::decode, Config}; + +use anyhow; +use libpt::log::{error, info}; +use rustls_pemfile::certs; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, +}; +use tokio_rustls::{ + rustls::{self, pki_types}, + TlsConnector, +}; +use webpki_roots; + +const BUF_SIZE: usize = 512; + +pub struct Client { + cfg: Config, + stream: TcpStream, + connector: TlsConnector, + domain: pki_types::ServerName<'static>, +} + +impl Client { + pub async fn build(cfg: Config) -> anyhow::Result { + let mut root_cert_store = rustls::RootCertStore::empty(); + root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + if cfg.certs.is_some() { + let mut reader = BufReader::new(File::open(cfg.certs.clone().unwrap())?); + for cert in certs(&mut reader) { + root_cert_store.add(cert?)? + } + } + let tls_config = rustls::ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + let connector = TlsConnector::from(Arc::new(tls_config)); + let stream = TcpStream::connect(&cfg.addr).await?; + let domain = match pki_types::ServerName::try_from(cfg.hostname.clone()) { + Ok(domain) => domain, + Err(err) => { + error!("Could not resolve hostname '{}': {err:?}", cfg.hostname); + return Err(err.into()); + }, + }; + + Ok(Client { + cfg: cfg.clone(), + stream, + connector, + domain, + }) + } + + pub async fn run(self) -> anyhow::Result<()> { + let mut stream = self.connector.connect(self.domain, self.stream).await?; + let mut buf = [0; BUF_SIZE]; + stream.write_all(b"ping").await?; + info!("> {:?} ping", self.cfg.hostname); + while stream.read(&mut buf).await? > 0 { + let response = decode(&buf)?; + info!("< {:?}\n{}", self.cfg.hostname, response); + stream.write_all(b"ping").await?; + info!("> {:?} ping", self.cfg.hostname); + // we should wait, so that we don't spam the client + std::thread::sleep(self.cfg.delay); + } + Ok(()) + } +} diff --git a/src/common/args.rs b/src/common/args.rs index 970d848..3d161f8 100644 --- a/src/common/args.rs +++ b/src/common/args.rs @@ -47,14 +47,13 @@ pub(crate) struct Cli { pub(crate) server: bool, /// Address of the server - pub(crate) addr: std::net::SocketAddr, + pub(crate) host: String, #[cfg(feature = "server")] #[arg(short, long)] - pub key: PathBuf, - #[cfg(feature = "server")] + pub key: Option, #[arg(short, long)] - pub certs: PathBuf, + pub certs: Option, } impl Cli { diff --git a/src/common/conf.rs b/src/common/conf.rs index 5ff59b9..f80bd52 100644 --- a/src/common/conf.rs +++ b/src/common/conf.rs @@ -1,6 +1,13 @@ use crate::common::args::Cli; -use std::{path::PathBuf, time::Duration}; +use std::{ + net::{SocketAddr, ToSocketAddrs}, + path::PathBuf, + time::Duration, +}; + +use anyhow::{anyhow, Result}; +use libpt::log::{error, trace}; const DEFAULT_TIMEOUT_LEN: u64 = 5000; // ms const DEFAULT_DELAY_LEN: u64 = 500; // ms @@ -9,22 +16,46 @@ const DEFAULT_WIN_AFTER: usize = 20; #[derive(Clone)] pub struct Config { pub addr: std::net::SocketAddr, + pub hostname: String, pub timeout: Duration, pub delay: Duration, + #[cfg(feature = "server")] pub win_after: usize, - pub key: PathBuf, - pub certs: PathBuf, + #[cfg(feature = "server")] + pub key: Option, + pub certs: Option, } impl Config { - pub fn new(cli: &Cli) -> Self { - Config { - addr: cli.addr.clone(), + pub fn build(cli: &Cli) -> Result { + let addr: SocketAddr = match cli.host.to_socket_addrs() { + Ok(mut addr) => addr.next().unwrap(), + Err(err) => { + error!( + "could not resolve host {:?} to a socket address: {err:?}", + cli.host.clone() + ); + return Err(anyhow!( + "could not resolve host {:?} to a socket address: {err:?}", + cli.host + )); + } + }; + let hostname = match cli.host.split_once(':') { + Some(hostname) => hostname.0.to_string(), + None => return Err(anyhow!("malformatted host (no port specified)")), + }; + trace!("config has resolved the given hostname to: {addr:?}"); + Ok(Config { + addr, + hostname, timeout: Duration::from_millis(DEFAULT_TIMEOUT_LEN), delay: Duration::from_millis(DEFAULT_DELAY_LEN), + #[cfg(feature = "server")] win_after: DEFAULT_WIN_AFTER, + #[cfg(feature = "server")] key: cli.key.clone(), certs: cli.certs.clone(), - } + }) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index f4200ec..9ed27e4 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,2 +1,12 @@ +use std::str::Utf8Error; + pub mod args; pub mod conf; + +#[inline] +pub fn decode(buf: &[u8]) -> Result { + match std::str::from_utf8(buf) { + Ok(s) => Ok(s.to_string()), + Err(err) => Err(err.into()), + } +} diff --git a/src/main.rs b/src/main.rs index 952dfa1..5c73f3c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,14 +13,14 @@ mod server; use common::{args::Cli, conf::*}; -use crate::server::Server; +use crate::{client::Client, server::Server}; #[tokio::main] async fn main() -> Result<()> { let cli = Cli::cli_parse(); debug!("dumping cli args:\n{:#?}", cli); - let cfg = Config::new(&cli); + let cfg = Config::build(&cli)?; #[cfg(feature = "server")] if cli.server { @@ -29,6 +29,5 @@ async fn main() -> Result<()> { } // implicit else, so we can work without the server feature info!("starting client"); - todo!(); - loop {} + return Client::build(cfg).await?.run().await; } diff --git a/src/server/mod.rs b/src/server/mod.rs index 7fc22ec..5ebf4d3 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -10,13 +10,13 @@ use libpt::log::{debug, error, info, trace, warn}; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls_pemfile::{certs, private_key}; use tokio::{ - io::{copy, sink, AsyncReadExt, AsyncWriteExt}, + io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream}, time::{self}, }; use tokio_rustls::{rustls, TlsAcceptor}; -use crate::common::conf::Config; +use crate::common::{conf::Config, decode}; pub mod errors; use errors::*; @@ -25,7 +25,6 @@ const BUF_SIZE: usize = 512; pub struct Server { cfg: Config, - pub timeout: Option, server: TcpListener, num_peers: AtomicUsize, acceptor: TlsAcceptor, @@ -42,11 +41,9 @@ impl Server { .with_single_cert(certs, key)?; let acceptor = TlsAcceptor::from(Arc::new(tls_config)); let server = TcpListener::bind(cfg.addr).await?; - let timeout = Some(Duration::from_secs(5)); Ok(Server { cfg, - timeout, server, num_peers: AtomicUsize::new(0), acceptor, @@ -107,13 +104,25 @@ impl Server { } fn load_key(cfg: Config) -> std::io::Result>> { - let key = private_key(&mut std::io::BufReader::new(File::open(cfg.key)?)); + if cfg.key.is_none() { + error!("the server needs a key!"); + return Err(std::io::ErrorKind::InvalidInput.into()); + } + let key = private_key(&mut std::io::BufReader::new(File::open( + cfg.key.clone().unwrap(), + )?)); return key; } fn load_certs(cfg: Config) -> std::io::Result>> { - match certs(&mut std::io::BufReader::new(File::open(&cfg.certs)?)) - .collect::>>>() + if cfg.certs.is_none() { + error!("the server needs at least one certificate!"); + return Err(std::io::ErrorKind::InvalidInput.into()); + } + match certs(&mut std::io::BufReader::new(File::open( + &cfg.certs.clone().unwrap(), + )?)) + .collect::>>>() { Ok(v) if !v.is_empty() => Ok(v), Ok(_) => { @@ -152,19 +161,21 @@ impl Server { let mut pings = 0; debug!("new peer: {:?}", addr); while stream.read(&mut buf).await? > 0 { - let request = self.decode(&buf)?; + let request = decode(&buf)?; debug!("< {:?}\n{}", addr, request); if request == "ping" { pings += 1; - } - if pings > 20 { - stream.write_all(b"You win!").await?; + if pings > self.cfg.win_after { + stream.write_all(b"You win!").await?; + stream.flush().await?; + stream.shutdown().await?; + break; + } + stream.write_all(b"pong").await?; + } else { + stream.write_all(b"what is the magic word?").await?; stream.flush().await?; - stream.shutdown().await?; - break; } - stream.write_all(b"pong").await?; - stream.flush().await?; // we should wait, so that we don't spam the client std::thread::sleep(self.cfg.delay); } @@ -172,11 +183,4 @@ impl Server { Ok(()) } - #[inline] - fn decode(&self, buf: &[u8]) -> Result { - match std::str::from_utf8(buf) { - Ok(s) => Ok(s.to_string()), - Err(err) => Err(err.into()), - } - } }