diff --git a/Cargo.toml b/Cargo.toml index 7adea65..ea0bcb7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ clap = { version = "4.5.3", features = ["derive"], optional = true } colored = { version = "2.1.0", optional = false } libpt = "0.4.2" rand = "0.8.5" +rayon = "1.10.0" regex = "1.10.3" serde = { version = "1.0.197", optional = true, features = ["serde_derive"] } serde_json = { version = "1.0.114", optional = true } diff --git a/src/bench/mod.rs b/src/bench/mod.rs index 0cace90..d25eca6 100644 --- a/src/bench/mod.rs +++ b/src/bench/mod.rs @@ -1,6 +1,8 @@ -use std::fmt::{Debug, Display}; +use std::fmt::Debug; +use std::sync::{Arc, Mutex}; use libpt::log::debug; +use rayon::prelude::*; use crate::error::WResult; use crate::game::response::GuessResponse; @@ -17,7 +19,7 @@ pub mod builtin; /// Default amount of games to play for a [Benchmark] pub const DEFAULT_N: usize = 50; -pub trait Benchmark<'wl, WL, SL>: Clone + Sized + Debug +pub trait Benchmark<'wl, WL, SL>: Clone + Sized + Debug + Sync where WL: WordList, WL: 'wl, @@ -40,21 +42,26 @@ where // TODO: add some interface to get reports while the benchmark runs // TODO: make the benchmark optionally multithreaded fn bench(&'wl self, n: usize) -> WResult { - // PERF: it would be better to make this multithreaded let part = match n / 20 { 0 => 19, other => other, }; - let mut report = Report::new(); + let report = Arc::new(Mutex::new(Report::new())); + let this = std::sync::Arc::new(self); - for i in 0..n { - report.add(self.play()?); - if i % part == part - 1 { - // TODO: add the report to the struct so that users can poll it to print the status - // TODO: update the report in the struct - } - } + (0..n) + .into_par_iter() + .for_each_with(report.clone(), |outside_data, _i| { + let report = outside_data; + let r = this + .play() + .expect("error playing the game during benchmark"); + report.lock().expect("lock is poisoned").add(r); + }); + // FIXME: find some way to take the Report from the Mutex + // Mutex::into_inner() does not work + let mut report: Report = report.lock().unwrap().clone(); report.finalize(); Ok(report) diff --git a/src/error.rs b/src/error.rs index 75b56e4..b5731b9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,6 @@ use thiserror::Error; +use crate::bench::report::Report; use crate::wlist::word::Word; pub type WResult = std::result::Result; @@ -25,6 +26,11 @@ pub enum Error { #[from] source: regex::Error, }, + #[error("Error sharing the benchmark data over multiple threads")] + Mutex { + #[from] + source: std::sync::PoisonError + } } #[derive(Debug, Clone, Error)] diff --git a/src/solve/mod.rs b/src/solve/mod.rs index 1615dc1..800f8a4 100644 --- a/src/solve/mod.rs +++ b/src/solve/mod.rs @@ -37,7 +37,7 @@ pub use stupid::StupidSolver; /// /// If you want to have the user select a model, create an enum with it's variants containing your /// [Solvers][Solver] and have this enum implement [Solver], see [AnyBuiltinSolver]. -pub trait Solver<'wl, WL: WordList>: Clone + std::fmt::Debug + Sized { +pub trait Solver<'wl, WL: WordList>: Clone + std::fmt::Debug + Sized + Sync { /// Build and initialize a [Solver] fn build(wordlist: &'wl WL) -> WResult; /// Calculate the next guess for a [Game] diff --git a/src/wlist/mod.rs b/src/wlist/mod.rs index ec7db95..6ed8cd1 100644 --- a/src/wlist/mod.rs +++ b/src/wlist/mod.rs @@ -15,7 +15,7 @@ use crate::error::WResult; pub type AnyWordlist = Box; -pub trait WordList: Clone + std::fmt::Debug + Default { +pub trait WordList: Clone + std::fmt::Debug + Default + Sync { fn solutions(&self) -> ManyWordDatas { let wmap = self.wordmap().clone(); let threshold = wmap.threshold();