diff --git a/src/client/mod.rs b/src/client/mod.rs index 10fc9c4..00a69e9 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -11,8 +11,7 @@ use tokio::{ net::TcpStream, }; use tokio_rustls::{ - rustls::{self, pki_types}, - TlsConnector, + rustls::{self, pki_types}, TlsConnector, TlsStream }; use webpki_roots; @@ -20,7 +19,7 @@ const BUF_SIZE: usize = 512; pub struct Client { cfg: Config, - stream: TcpStream, + stream: TlsStream, connector: TlsConnector, domain: pki_types::ServerName<'static>, } @@ -41,7 +40,6 @@ impl Client { .with_root_certificates(root_cert_store) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(tls_config)); - let stream = TcpStream::connect(&cfg.addr).await?; let domain = match pki_types::ServerName::try_from(cfg.hostname.clone()) { Ok(domain) => domain, Err(err) => { @@ -49,6 +47,7 @@ impl Client { return Err(err.into()); } }; + let mut stream = connector.connect(domain, TcpStream::connect(&cfg.addr).await?).await?; Ok(Client { cfg: cfg.clone(), @@ -58,18 +57,17 @@ impl Client { }) } - pub async fn run(self) -> anyhow::Result<()> { - let mut stream = self.connector.connect(self.domain, self.stream).await?; + pub async fn run(mut self) -> anyhow::Result<()> { let mut buf = [0; BUF_SIZE]; - stream.write_all(b"ping").await?; + self.stream.write_all(b"ping").await?; info!("> ({}) ping", self.cfg.hostname); - while stream.read(&mut buf).await? > 0 { + while self.stream.read(&mut buf).await? > 0 { let response = decode(&buf)?; info!("< ({}) {}", self.cfg.hostname, response); if response == "You win!" { break; } - stream.write_all(b"ping").await?; + self.stream.write_all(b"ping").await?; info!("> ({}) ping", self.cfg.hostname); // we should wait, so that we don't spam the client std::thread::sleep(self.cfg.delay);