diff --git a/data/test.pem b/data/key.pem similarity index 100% rename from data/test.pem rename to data/key.pem diff --git a/src/server/mod.rs b/src/server/mod.rs index 3cc528b..d1e7498 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -6,7 +6,8 @@ use std::{ time::Duration, }; -use libpt::log::{debug, error, info, warn}; +use anyhow::anyhow; +use libpt::log::{debug, error, info, trace, warn}; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls_pemfile::{certs, private_key}; use tokio::{ @@ -34,7 +35,9 @@ pub struct Server { impl Server { pub async fn build(cfg: Config) -> anyhow::Result { let certs = Self::load_certs(cfg.clone())?; + trace!("loaded certs: {:?}", certs); let key = Self::load_key(cfg.clone())?.expect("bad key?"); + trace!("loaded key: {:?}", key); let tls_config = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key)?; @@ -50,6 +53,7 @@ impl Server { acceptor, }) } + pub async fn run(self) -> anyhow::Result<()> { let rc_self = Arc::new(self); let ref_self = rc_self.clone(); @@ -104,11 +108,24 @@ impl Server { } fn load_key(cfg: Config) -> std::io::Result>> { - private_key(&mut std::io::BufReader::new(File::open(cfg.key)?)) + let key = private_key(&mut std::io::BufReader::new(File::open(cfg.key)?)); + return key; } fn load_certs(cfg: Config) -> std::io::Result>> { - certs(&mut std::io::BufReader::new(File::open(cfg.key)?)).collect() + match certs(&mut std::io::BufReader::new(File::open(&cfg.certs)?)) + .collect::>>>() + { + Ok(v) if !v.is_empty() => Ok(v), + Ok(_) => { + error!("no certs found in provided file {:?}", cfg.certs); + return Err(std::io::ErrorKind::InvalidInput.into()); + } + Err(err) => { + error!("could not load certs: {err:?}"); + return Err(err); + } + } } #[inline] @@ -129,25 +146,13 @@ impl Server { async fn handle_stream( &self, - stream: tokio_rustls::server::TlsStream, + mut stream: tokio_rustls::server::TlsStream, addr: SocketAddr, ) -> Result<()> { debug!("new peer: {:?}", addr); let mut buf = [0; BUF_SIZE]; - let (mut reader, mut writer) = split(stream); - loop { - match reader.read(&mut buf).await { - Ok(len) if len == 0 => { - break; - } - Ok(_) => (), - Err(err) => { - eprintln!("reader.read err: {err}") - } - } - debug!("< {addr:?} : \"{}\"", self.decode(&buf)?); - - writer.write(b"pong\0").await?; + while stream.read(&mut buf).await? != 0 { + stream.write_all(b"pong\0"); // we should wait, so that we don't spam the client std::thread::sleep(self.cfg.delay);