mainly trying to get the client to work
cargo devel CI / cargo CI (push) Successful in 2m27s Details

This commit is contained in:
Christoph J. Scherr 2024-01-24 13:22:16 +01:00
parent 4595f37f4c
commit cc80d0afae
Signed by: cscherrNT
GPG Key ID: 8E2B45BC51A27EA7
7 changed files with 155 additions and 38 deletions

View File

@ -23,6 +23,7 @@ thiserror = "1.0.56"
tokio = { version = "1.35.1", features = ["net", "rt", "macros"] } tokio = { version = "1.35.1", features = ["net", "rt", "macros"] }
rustls-pemfile = "2.0.0" rustls-pemfile = "2.0.0"
tokio-rustls = "0.25.0" tokio-rustls = "0.25.0"
webpki-roots = "0.26.0"
[features] [features]
default = ["server"] default = ["server"]

View File

@ -1 +1,74 @@
#![cfg(feature = "server")]
use std::{fs::File, io::BufReader, sync::Arc};
use crate::{common::decode, Config};
use anyhow;
use libpt::log::{error, info};
use rustls_pemfile::certs;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
use tokio_rustls::{
rustls::{self, pki_types},
TlsConnector,
};
use webpki_roots;
const BUF_SIZE: usize = 512;
pub struct Client {
cfg: Config,
stream: TcpStream,
connector: TlsConnector,
domain: pki_types::ServerName<'static>,
}
impl Client {
pub async fn build(cfg: Config) -> anyhow::Result<Self> {
let mut root_cert_store = rustls::RootCertStore::empty();
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
if cfg.certs.is_some() {
let mut reader = BufReader::new(File::open(cfg.certs.clone().unwrap())?);
for cert in certs(&mut reader) {
root_cert_store.add(cert?)?
}
}
let tls_config = rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(tls_config));
let stream = TcpStream::connect(&cfg.addr).await?;
let domain = match pki_types::ServerName::try_from(cfg.hostname.clone()) {
Ok(domain) => domain,
Err(err) => {
error!("Could not resolve hostname '{}': {err:?}", cfg.hostname);
return Err(err.into());
},
};
Ok(Client {
cfg: cfg.clone(),
stream,
connector,
domain,
})
}
pub async fn run(self) -> anyhow::Result<()> {
let mut stream = self.connector.connect(self.domain, self.stream).await?;
let mut buf = [0; BUF_SIZE];
stream.write_all(b"ping").await?;
info!("> {:?} ping", self.cfg.hostname);
while stream.read(&mut buf).await? > 0 {
let response = decode(&buf)?;
info!("< {:?}\n{}", self.cfg.hostname, response);
stream.write_all(b"ping").await?;
info!("> {:?} ping", self.cfg.hostname);
// we should wait, so that we don't spam the client
std::thread::sleep(self.cfg.delay);
}
Ok(())
}
}

View File

@ -47,14 +47,13 @@ pub(crate) struct Cli {
pub(crate) server: bool, pub(crate) server: bool,
/// Address of the server /// Address of the server
pub(crate) addr: std::net::SocketAddr, pub(crate) host: String,
#[cfg(feature = "server")] #[cfg(feature = "server")]
#[arg(short, long)] #[arg(short, long)]
pub key: PathBuf, pub key: Option<PathBuf>,
#[cfg(feature = "server")]
#[arg(short, long)] #[arg(short, long)]
pub certs: PathBuf, pub certs: Option<PathBuf>,
} }
impl Cli { impl Cli {

View File

@ -1,6 +1,13 @@
use crate::common::args::Cli; use crate::common::args::Cli;
use std::{path::PathBuf, time::Duration}; use std::{
net::{SocketAddr, ToSocketAddrs},
path::PathBuf,
time::Duration,
};
use anyhow::{anyhow, Result};
use libpt::log::{error, trace};
const DEFAULT_TIMEOUT_LEN: u64 = 5000; // ms const DEFAULT_TIMEOUT_LEN: u64 = 5000; // ms
const DEFAULT_DELAY_LEN: u64 = 500; // ms const DEFAULT_DELAY_LEN: u64 = 500; // ms
@ -9,22 +16,46 @@ const DEFAULT_WIN_AFTER: usize = 20;
#[derive(Clone)] #[derive(Clone)]
pub struct Config { pub struct Config {
pub addr: std::net::SocketAddr, pub addr: std::net::SocketAddr,
pub hostname: String,
pub timeout: Duration, pub timeout: Duration,
pub delay: Duration, pub delay: Duration,
#[cfg(feature = "server")]
pub win_after: usize, pub win_after: usize,
pub key: PathBuf, #[cfg(feature = "server")]
pub certs: PathBuf, pub key: Option<PathBuf>,
pub certs: Option<PathBuf>,
} }
impl Config { impl Config {
pub fn new(cli: &Cli) -> Self { pub fn build(cli: &Cli) -> Result<Self> {
Config { let addr: SocketAddr = match cli.host.to_socket_addrs() {
addr: cli.addr.clone(), Ok(mut addr) => addr.next().unwrap(),
Err(err) => {
error!(
"could not resolve host {:?} to a socket address: {err:?}",
cli.host.clone()
);
return Err(anyhow!(
"could not resolve host {:?} to a socket address: {err:?}",
cli.host
));
}
};
let hostname = match cli.host.split_once(':') {
Some(hostname) => hostname.0.to_string(),
None => return Err(anyhow!("malformatted host (no port specified)")),
};
trace!("config has resolved the given hostname to: {addr:?}");
Ok(Config {
addr,
hostname,
timeout: Duration::from_millis(DEFAULT_TIMEOUT_LEN), timeout: Duration::from_millis(DEFAULT_TIMEOUT_LEN),
delay: Duration::from_millis(DEFAULT_DELAY_LEN), delay: Duration::from_millis(DEFAULT_DELAY_LEN),
#[cfg(feature = "server")]
win_after: DEFAULT_WIN_AFTER, win_after: DEFAULT_WIN_AFTER,
#[cfg(feature = "server")]
key: cli.key.clone(), key: cli.key.clone(),
certs: cli.certs.clone(), certs: cli.certs.clone(),
} })
} }
} }

View File

@ -1,2 +1,12 @@
use std::str::Utf8Error;
pub mod args; pub mod args;
pub mod conf; pub mod conf;
#[inline]
pub fn decode(buf: &[u8]) -> Result<String, Utf8Error> {
match std::str::from_utf8(buf) {
Ok(s) => Ok(s.to_string()),
Err(err) => Err(err.into()),
}
}

View File

@ -13,14 +13,14 @@ mod server;
use common::{args::Cli, conf::*}; use common::{args::Cli, conf::*};
use crate::server::Server; use crate::{client::Client, server::Server};
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
let cli = Cli::cli_parse(); let cli = Cli::cli_parse();
debug!("dumping cli args:\n{:#?}", cli); debug!("dumping cli args:\n{:#?}", cli);
let cfg = Config::new(&cli); let cfg = Config::build(&cli)?;
#[cfg(feature = "server")] #[cfg(feature = "server")]
if cli.server { if cli.server {
@ -29,6 +29,5 @@ async fn main() -> Result<()> {
} }
// implicit else, so we can work without the server feature // implicit else, so we can work without the server feature
info!("starting client"); info!("starting client");
todo!(); return Client::build(cfg).await?.run().await;
loop {}
} }

View File

@ -10,13 +10,13 @@ use libpt::log::{debug, error, info, trace, warn};
use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls_pemfile::{certs, private_key}; use rustls_pemfile::{certs, private_key};
use tokio::{ use tokio::{
io::{copy, sink, AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream}, net::{TcpListener, TcpStream},
time::{self}, time::{self},
}; };
use tokio_rustls::{rustls, TlsAcceptor}; use tokio_rustls::{rustls, TlsAcceptor};
use crate::common::conf::Config; use crate::common::{conf::Config, decode};
pub mod errors; pub mod errors;
use errors::*; use errors::*;
@ -25,7 +25,6 @@ const BUF_SIZE: usize = 512;
pub struct Server { pub struct Server {
cfg: Config, cfg: Config,
pub timeout: Option<Duration>,
server: TcpListener, server: TcpListener,
num_peers: AtomicUsize, num_peers: AtomicUsize,
acceptor: TlsAcceptor, acceptor: TlsAcceptor,
@ -42,11 +41,9 @@ impl Server {
.with_single_cert(certs, key)?; .with_single_cert(certs, key)?;
let acceptor = TlsAcceptor::from(Arc::new(tls_config)); let acceptor = TlsAcceptor::from(Arc::new(tls_config));
let server = TcpListener::bind(cfg.addr).await?; let server = TcpListener::bind(cfg.addr).await?;
let timeout = Some(Duration::from_secs(5));
Ok(Server { Ok(Server {
cfg, cfg,
timeout,
server, server,
num_peers: AtomicUsize::new(0), num_peers: AtomicUsize::new(0),
acceptor, acceptor,
@ -107,13 +104,25 @@ impl Server {
} }
fn load_key(cfg: Config) -> std::io::Result<Option<PrivateKeyDer<'static>>> { fn load_key(cfg: Config) -> std::io::Result<Option<PrivateKeyDer<'static>>> {
let key = private_key(&mut std::io::BufReader::new(File::open(cfg.key)?)); if cfg.key.is_none() {
error!("the server needs a key!");
return Err(std::io::ErrorKind::InvalidInput.into());
}
let key = private_key(&mut std::io::BufReader::new(File::open(
cfg.key.clone().unwrap(),
)?));
return key; return key;
} }
fn load_certs(cfg: Config) -> std::io::Result<Vec<CertificateDer<'static>>> { fn load_certs(cfg: Config) -> std::io::Result<Vec<CertificateDer<'static>>> {
match certs(&mut std::io::BufReader::new(File::open(&cfg.certs)?)) if cfg.certs.is_none() {
.collect::<std::io::Result<Vec<CertificateDer<'static>>>>() error!("the server needs at least one certificate!");
return Err(std::io::ErrorKind::InvalidInput.into());
}
match certs(&mut std::io::BufReader::new(File::open(
&cfg.certs.clone().unwrap(),
)?))
.collect::<std::io::Result<Vec<CertificateDer<'static>>>>()
{ {
Ok(v) if !v.is_empty() => Ok(v), Ok(v) if !v.is_empty() => Ok(v),
Ok(_) => { Ok(_) => {
@ -152,19 +161,21 @@ impl Server {
let mut pings = 0; let mut pings = 0;
debug!("new peer: {:?}", addr); debug!("new peer: {:?}", addr);
while stream.read(&mut buf).await? > 0 { while stream.read(&mut buf).await? > 0 {
let request = self.decode(&buf)?; let request = decode(&buf)?;
debug!("< {:?}\n{}", addr, request); debug!("< {:?}\n{}", addr, request);
if request == "ping" { if request == "ping" {
pings += 1; pings += 1;
} if pings > self.cfg.win_after {
if pings > 20 { stream.write_all(b"You win!").await?;
stream.write_all(b"You win!").await?; stream.flush().await?;
stream.shutdown().await?;
break;
}
stream.write_all(b"pong").await?;
} else {
stream.write_all(b"what is the magic word?").await?;
stream.flush().await?; stream.flush().await?;
stream.shutdown().await?;
break;
} }
stream.write_all(b"pong").await?;
stream.flush().await?;
// we should wait, so that we don't spam the client // we should wait, so that we don't spam the client
std::thread::sleep(self.cfg.delay); std::thread::sleep(self.cfg.delay);
} }
@ -172,11 +183,4 @@ impl Server {
Ok(()) Ok(())
} }
#[inline]
fn decode(&self, buf: &[u8]) -> Result<String> {
match std::str::from_utf8(buf) {
Ok(s) => Ok(s.to_string()),
Err(err) => Err(err.into()),
}
}
} }