diff --git a/src/main.rs b/src/main.rs index d2953e7..4678d59 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,6 +19,10 @@ struct Cli { /// Machine output #[arg(short, long, default_value_t = false, global = true)] machine: bool, + + /// Verbose output + #[arg(short, long, default_value_t = false, global = true)] + verbose: bool, } #[derive(Subcommand, Debug)] @@ -79,7 +83,7 @@ pub fn main() { let b = num_bigint::BigInt::from_str(&mod_exp_args.base.as_str()).expect("could not make bigint"); let e = num_bigint::BigInt::from_str(&mod_exp_args.exp.as_str()).expect("could not make bigint"); let f = num_bigint::BigInt::from_str(&mod_exp_args.field.as_str()).expect("could not make bigint"); - let result = modular_exponentiation(b.clone(), e, f); + let result = modular_exponentiation(b.clone(), e, f, args.verbose); if args.machine { println!("{}", result) } diff --git a/src/modular_exponentiation.rs b/src/modular_exponentiation.rs index 19f76cb..588b38c 100644 --- a/src/modular_exponentiation.rs +++ b/src/modular_exponentiation.rs @@ -1,9 +1,7 @@ #![allow(dead_code)] -use num_bigint::BigInt; +use num_bigint::{BigInt, BigUint}; use num_traits::ToPrimitive; -use num_traits::FromPrimitive; -use pyo3::exceptions::PyException; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -16,48 +14,64 @@ pub fn calc_exp_in_field_lib( } /** - * modular exponentiation algorithm with big numbers. - * - * Umwandlung des Exponenten k in die zugehörige Binärdarstellung. - * Ersetzen jeder 0 durch Q und jeder 1 durch QM. - * Nun wird Q als Anweisung zum Quadrieren und M als Anweisung zum Multiplizieren aufgefasst. - * Somit bildet die resultierende Zeichenkette von links nach rechts gelesen eine Vorschrift zur Berechnung von x k . - * Man beginne mit 1, quadriere für jedes gelesene Q das bisherige Zwischenergebnis und - * multipliziere es für jedes gelesene M mit x . + * square and multiply */ pub fn modular_exponentiation( base: BigInt, - orig_exp: BigInt, - field: BigInt) -> BigInt { - let binary_repr = orig_exp.to_bytes_be(); + exp: BigInt, + field: BigInt, + verbose: bool) -> BigInt { + if verbose { + println!("args:\nbase {base}\nexp {exp}\nfield {field}\nverbose {verbose}"); + } + let mut instructions: Vec = bigint_to_bools(exp); + // remove the signing bit + if verbose { + println!("pre instructions {:?}",instructions); + } - let instructions: Vec = bytes_to_bools(&binary_repr.1); + instructions.reverse(); + if verbose { + println!("instructions {:?}",instructions); + } - let mut exp = BigInt::from(1); + let mut res = base.clone(); for instr in instructions { - if instr { + if verbose { + println!("current res: {res}"); + } + if !instr { // square - exp = (exp.pow(2) * &base) % &field; + if verbose { + println!("square"); + } + res = res.pow(2) % &field; } else { // square and multiply - exp = exp.pow(2) % &field; + if verbose { + println!("square and multiply"); + } + res = (res.pow(2) * &base) % &field; } } - return exp; + return res; } #[pyfunction] #[pyo3(name="modular_exponentiation")] +#[pyo3(signature=(base, orig_exp, field, verbose = false))] pub fn py_modular_exponentiation( base: i128, orig_exp: i128, - field: i128) -> PyResult { + field: i128, + verbose: bool) -> PyResult { let big_res = modular_exponentiation( BigInt::from(base), BigInt::from(orig_exp), - BigInt::from(field) + BigInt::from(field), + verbose ); let res = big_res.to_u128(); match res { @@ -71,17 +85,46 @@ pub fn py_modular_exponentiation( } -// Vec to Vec ( binary representation interpreted otherwise ) -fn bytes_to_bools(bytes: &Vec) -> Vec { +/// Dont use this buggy mess +pub fn binary_exponentiation(base: BigInt, exp: BigInt, verbose: bool) -> BigInt { + if exp.clone() < BigInt::from(0) { + return binary_exponentiation(1/&base, -exp, verbose); + } + else if exp.clone() == BigInt::from(0) { + return BigInt::from(1); + } + else if exp.clone() % 2 == BigInt::from(0) { + return binary_exponentiation(&base*&base, &exp/2, verbose); + } + else if exp.clone() % 2 == BigInt::from(1) { + return binary_exponentiation(&base*&base, (&exp-1)/2, verbose); + } + else { + panic!("I don't know how we got here") + } +} + +fn bigint_to_bools(item: BigInt) -> Vec { let mut result: Vec = Vec::new(); - for byte in bytes { - for c in format!("{:b}", byte).chars() { - result.push(c == '1'); + let mut modul : BigInt; + let mut smaller = item; + loop { + if smaller < BigInt::from(2) { + break; + } + modul = &smaller % BigInt::from(2); + smaller = &smaller / BigInt::from(2); + if modul == BigInt::from(1) { + result.push(true); + } + else { + result.push(false); } } result } + fn dump_bin(bytes: &Vec) { for byte in bytes.iter() { println!("{:#08b}\t| {:#02x}", byte, byte);