generated from PlexSheep/rs-base
76 lines
2.4 KiB
Rust
76 lines
2.4 KiB
Rust
#![cfg(feature = "server")]
|
|
use std::{fs::File, io::BufReader, sync::Arc};
|
|
|
|
use crate::{common::decode, Config};
|
|
|
|
use anyhow;
|
|
use libpt::log::{error, info, trace};
|
|
use rustls_pemfile::certs;
|
|
use tokio::{
|
|
io::{AsyncReadExt, AsyncWriteExt},
|
|
net::TcpStream,
|
|
};
|
|
use tokio_rustls::{
|
|
rustls::{self, pki_types},
|
|
TlsConnector,
|
|
};
|
|
use webpki_roots;
|
|
|
|
/// Size, in bytes, of the fixed buffer that `Client::run` reads responses into.
const BUF_SIZE: usize = 512;
|
|
|
|
pub struct Client {
|
|
cfg: Config,
|
|
stream: TcpStream,
|
|
connector: TlsConnector,
|
|
domain: pki_types::ServerName<'static>,
|
|
}
|
|
|
|
impl Client {
|
|
pub async fn build(cfg: Config) -> anyhow::Result<Self> {
|
|
let mut root_cert_store = rustls::RootCertStore::empty();
|
|
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
|
if cfg.certs.is_some() {
|
|
let mut reader = BufReader::new(File::open(cfg.certs.clone().unwrap())?);
|
|
for cert in certs(&mut reader) {
|
|
trace!("found custom cert: {cert:?}");
|
|
root_cert_store.add(cert?)?
|
|
}
|
|
}
|
|
trace!("root cert store: {root_cert_store:?}");
|
|
let tls_config = rustls::ClientConfig::builder()
|
|
.with_root_certificates(root_cert_store)
|
|
.with_no_client_auth();
|
|
let connector = TlsConnector::from(Arc::new(tls_config));
|
|
let stream = TcpStream::connect(&cfg.addr).await?;
|
|
let domain = match pki_types::ServerName::try_from(cfg.hostname.clone()) {
|
|
Ok(domain) => domain,
|
|
Err(err) => {
|
|
error!("Could not resolve hostname '{}': {err:?}", cfg.hostname);
|
|
return Err(err.into());
|
|
}
|
|
};
|
|
|
|
Ok(Client {
|
|
cfg: cfg.clone(),
|
|
stream,
|
|
connector,
|
|
domain,
|
|
})
|
|
}
|
|
|
|
pub async fn run(self) -> anyhow::Result<()> {
|
|
let mut stream = self.connector.connect(self.domain, self.stream).await?;
|
|
let mut buf = [0; BUF_SIZE];
|
|
stream.write_all(b"ping").await?;
|
|
info!("> {:?} ping", self.cfg.hostname);
|
|
while stream.read(&mut buf).await? > 0 {
|
|
let response = decode(&buf)?;
|
|
info!("< {:?}\n{}", self.cfg.hostname, response);
|
|
stream.write_all(b"ping").await?;
|
|
info!("> {:?} ping", self.cfg.hostname);
|
|
// we should wait, so that we don't spam the client
|
|
std::thread::sleep(self.cfg.delay);
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|