From ca473824fe98e78b92d9fcf05999707efe8f1de3 Mon Sep 17 00:00:00 2001 From: "Christoph J. Scherr" Date: Thu, 21 Mar 2024 15:31:32 +0100 Subject: [PATCH] freq as struct --- data/wordlists/test.json | 3 ++ src/bin/game/cli.rs | 2 + src/game/mod.rs | 10 +++-- src/wlist/builtin.rs | 32 ++++++++++++---- src/wlist/mod.rs | 7 ++++ src/wlist/word.rs | 79 ++++++++++++++++++++++++++++++++++++++-- 6 files changed, 119 insertions(+), 14 deletions(-) create mode 100644 data/wordlists/test.json diff --git a/data/wordlists/test.json b/data/wordlists/test.json new file mode 100644 index 0000000..ce95e4f --- /dev/null +++ b/data/wordlists/test.json @@ -0,0 +1,3 @@ +{ + "word": 0.001 +} diff --git a/src/bin/game/cli.rs b/src/bin/game/cli.rs index c5e3ec3..9d64210 100644 --- a/src/bin/game/cli.rs +++ b/src/bin/game/cli.rs @@ -31,6 +31,8 @@ fn main() -> anyhow::Result<()> { .precompute(cli.precompute) .build()?; + debug!("game: {:#?}", game); + Ok(()) } diff --git a/src/game/mod.rs b/src/game/mod.rs index c3090be..58c4b07 100644 --- a/src/game/mod.rs +++ b/src/game/mod.rs @@ -1,3 +1,4 @@ +use crate::wlist::word::Word; use crate::wlist::WordList; #[derive(Debug, Clone, PartialEq, Eq)] @@ -9,7 +10,7 @@ where precompute: bool, max_steps: usize, step: usize, - solution: String, + solution: Word, wordlist: WL, } @@ -31,16 +32,17 @@ impl Game { /// /// This function will return an error if . pub(crate) fn build(length: usize, precompute: bool, max_steps: usize, wlist: WL) -> anyhow::Result { - let _game = Game { + let mut game = Game { length, precompute, max_steps, step: 0, - solution: String::default(), // we actually set this later + solution: Word::default(), // we actually set this later wordlist: wlist }; - todo!(); + game.solution = game.wordlist.rand_solution().into(); + Ok(game) } } diff --git a/src/wlist/builtin.rs b/src/wlist/builtin.rs index 56d149d..8a24326 100644 --- a/src/wlist/builtin.rs +++ b/src/wlist/builtin.rs @@ -1,30 +1,48 @@ +use std::fmt::{write, Debug}; + use serde_json; -use super::Word; +use super::{Word, WordList}; const RAW_WORDLIST_FILE: &str = include_str!("../../data/wordlists/en_US_3b1b_freq_map.json"); -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct BuiltinWList { - words: super::WordMap + words: super::WordMap, } impl super::WordList for BuiltinWList { fn solutions(&self) -> Vec<&Word> { // PERF: this can be made faster if we were to use parallel iterators or chunking - self.words.keys().collect() + // TODO: Filter should be a bit more elegant + let threshold = self.total_freq() / 2; + self.wordmap().iter().filter(|i| i.1 > ) } fn length_range(&self) -> impl std::ops::RangeBounds { 5..5 } + fn wordmap(&self) -> &super::WordMap { + &self.words + } } impl Default for BuiltinWList { fn default() -> Self { let words: super::WordMap = serde_json::from_str(RAW_WORDLIST_FILE).unwrap(); - Self { - words - } + Self { words } + } +} + +impl Debug for BuiltinWList { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write( + f, + format_args!( + "BuiltinWList {{ amount: {}, total_freq: {} }}", + self.amount(), + self.total_freq() + ), + ) } } diff --git a/src/wlist/mod.rs b/src/wlist/mod.rs index dab7484..86efc46 100644 --- a/src/wlist/mod.rs +++ b/src/wlist/mod.rs @@ -22,4 +22,11 @@ pub trait WordList: Clone + std::fmt::Debug + Default { self.solutions().iter().choose(&mut rng).unwrap() } fn length_range(&self) -> impl RangeBounds; + fn amount(&self) -> usize { + self.solutions().len() + } + fn wordmap(&self) -> &WordMap; + fn total_freq(&self) -> Frequency { + self.wordmap().values().map(|a| a.to_owned()).sum() + } } diff --git a/src/wlist/word.rs b/src/wlist/word.rs index f639dfc..6b65c9d 100644 --- a/src/wlist/word.rs +++ b/src/wlist/word.rs @@ -1,25 +1,98 @@ use std::collections::HashMap; +use std::fmt::{write, Display}; +use std::iter::Sum; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; // NOTE: We might need a different implementation for more precision +// NOTE: This struct requires a custom Serialize and Deserialize implementation #[derive(Clone, Debug, PartialEq, PartialOrd)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Frequency { - inner: f64 + inner: f64, } + // PERF: Hash for String is probably a bottleneck pub type Word = String; #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct WordMap { - inner: HashMap + #[serde(flatten)] + inner: HashMap, } impl WordMap { pub fn keys(&self) -> std::collections::hash_map::Keys<'_, String, Frequency> { self.inner.keys() } + pub fn values(&self) -> std::collections::hash_map::Values<'_, String, Frequency> { + self.inner.values() + } + pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Frequency> { + self.inner.iter() + } +} + +// We need custom Serialize and Deserialize of Frequency, because they are only primitive types. +// Serde does not support serializing directly to and from primitives (such as floats) +#[cfg(feature = "serde")] +impl<'de> Deserialize<'de> for Frequency { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct FrequencyVisitor; + impl<'v> serde::de::Visitor<'v> for FrequencyVisitor { + type Value = Frequency; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(formatter, "a floating-point number") + } + + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + Ok(Frequency { inner: v }) + } + } + + deserializer.deserialize_any(FrequencyVisitor) + } +} +#[cfg(feature = "serde")] +impl Serialize for Frequency { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_f64(self.inner) + } +} + +impl From for f64 { + fn from(value: Frequency) -> Self { + value.inner + } +} + +impl From for Frequency { + fn from(value: f64) -> Self { + Frequency { inner: value } + } +} + +impl Sum for Frequency { + fn sum>(iter: I) -> Self { + iter.fold(Self { inner: 0.0 }, |a, b| Self { + inner: a.inner + b.inner, + }) + } +} + +impl Display for Frequency { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write(f, format_args!("{}", self.inner)) + } }