From beeea3597ecec95de29d192d486ffe390f9300c1 Mon Sep 17 00:00:00 2001 From: Jay Robson Date: Sat, 10 May 2025 18:05:48 +1000 Subject: [PATCH] added optimisations --- src/ballot.rs | 28 +++++++++----- src/counter.rs | 90 +++++++++++++++++-------------------------- src/csv/reader.rs | 77 +++++++++--------------------------- src/csv/reader/row.rs | 85 ++++++++++++++++++++++++++++++++++++++++ src/header.rs | 8 ++-- src/main.rs | 7 ++-- src/util.rs | 2 + src/util/mutex.rs | 17 ++++++++ src/util/thread.rs | 31 +++++++++++++++ 9 files changed, 216 insertions(+), 129 deletions(-) create mode 100644 src/csv/reader/row.rs create mode 100644 src/util/mutex.rs create mode 100644 src/util/thread.rs diff --git a/src/ballot.rs b/src/ballot.rs index 107afc6..e58f7bf 100644 --- a/src/ballot.rs +++ b/src/ballot.rs @@ -1,6 +1,6 @@ use itertools::Itertools; -use crate::{csv, header::Header, stats::Stats}; +use crate::{csv, header::Header, stats::Stats, util::ScoreItem}; #[derive(Debug)] @@ -12,7 +12,7 @@ pub struct Ballot { const CANDIDATE_MIN: usize = 6; impl Ballot { - pub fn parse(mut row: csv::reader::RowReader, header: &Header, stats: &mut Stats) -> Result> { + pub fn parse(mut row: csv::reader::Row, header: &Header, stats: &mut Stats) -> Result> { let mut cols = row.by_ref().skip(6); let place_filter = |(index, place): (usize, String)| match place.parse::() { @@ -45,22 +45,32 @@ impl Ballot { } else { - println!("abl votes: {votes_party:?}"); - println!("btl votes: {votes_candidate:?}"); - println!("at: {}", row.get_line_number()); return Err("ballot is informal".into()); } stats.total += 1; Ok(Ballot { weight: 1.0, votes }) } - pub fn get_weight(&self) -> f64 { - self.weight + pub fn count_primary(&self, scores: &mut [f64], enabled: &[bool]) { + if let Some(index) = self.get_primary_index(enabled) { + scores[index] += self.weight; + } } - pub fn apply_weight(&mut self, weight: f64) { + pub fn apply_winners(&mut self, enabled: &[bool], winners: &[ScoreItem], quota: f64) { + let index = match self.get_primary_index(enabled) { + Some(v) => v, + None => return, + }; + if let Some(winner) = winners.iter().filter(|&v| v.index == index).next() { + self.apply_weight((winner.value - quota) / winner.value); + } + } + #[inline] + fn apply_weight(&mut self, weight: f64) { self.weight *= weight; } - pub fn get_primary_index(&self, enabled: &[bool]) -> Option { + #[inline] + fn get_primary_index(&self, enabled: &[bool]) -> Option { for &id in self.votes.iter() { if enabled[id] { return Some(id); diff --git a/src/counter.rs b/src/counter.rs index d2054dd..729a12c 100644 --- a/src/counter.rs +++ b/src/counter.rs @@ -1,7 +1,7 @@ use std::{sync::{Arc, Mutex}, thread}; use itertools::Itertools; -use crate::{ballot::Ballot, csv, header::Header, stats::Stats, util::ScoreItem}; +use crate::{ballot::Ballot, csv, header::Header, stats::Stats, util::{self, ScoreItem}}; #[derive(Debug)] pub enum Event { @@ -19,10 +19,10 @@ pub struct Counter { quota: f64, } +const CHUNK_SIZE: usize = 1024; + impl Counter { - pub fn new<'a,I>(mut csv: I, winners: usize) -> Result> - where I: Iterator> + Send + Sync - { + pub fn new(mut csv: csv::reader::CsvReader, winners: usize) -> Result> { let header = Arc::new(Header::parse(csv.next().ok_or("csv header missing")?, winners)?); let ballots = Mutex::new(Vec::new()); let stats = Mutex::new(Stats::new()); @@ -32,22 +32,25 @@ impl Counter { return Err("winners can't be smaller than the candidates list".into()); } - thread::scope(|s| { - let threads = std::thread::available_parallelism().unwrap().into(); - for _ in 0..threads { - s.spawn(|| { - let mut l_stats = Stats::new(); - let mut l_ballots = Vec::new(); - while let Some(mut row) = { csv.lock().unwrap().next() } { - if !row.check_empty() { - l_ballots.push(Ballot::parse(row, &header, &mut l_stats).unwrap()); - } - } - {ballots.lock().unwrap().append(&mut l_ballots)}; - {stats.lock().unwrap().add(&l_stats)}; - }); + util::thread::spawn_all_with_result(|| -> Result<(), Box> { + let mut l_stats = Stats::new(); + let mut l_ballots = Vec::new(); + loop { + let rows = csv.lock().unwrap().by_ref().take(CHUNK_SIZE).collect_vec(); + + if rows.len() == 0 { + break; + } + + for row in rows { + l_ballots.push(Ballot::parse(row, &header, &mut l_stats).unwrap()); + } + + ballots.lock().unwrap().append(&mut l_ballots); } - }); + {stats.lock().unwrap().add(&l_stats)}; + Ok(()) + })?; let ballots = ballots.into_inner().unwrap(); let stats = stats.into_inner().unwrap(); @@ -67,24 +70,15 @@ impl Counter { } fn count_primaries(&self) -> Vec { let mut scores = Mutex::new(vec![0.0; self.enabled.len()]); - let ballot_chunks = Mutex::new(self.ballots.chunks(1024)); + let ballot_chunks = Mutex::new(self.ballots.chunks(CHUNK_SIZE)); - thread::scope(|s| { - let threads = std::thread::available_parallelism().unwrap().into(); - for _ in 0..threads { - s.spawn(|| { - let mut l_scores = vec![0.0; self.enabled.len()]; - while let Some(ballots) = { ballot_chunks.lock().unwrap().next() } { - for ballot in ballots { - if let Some(index) = ballot.get_primary_index(&self.enabled) { - l_scores[index] += ballot.get_weight(); - } - } - } - for (dst, src) in scores.lock().unwrap().iter_mut().zip(l_scores.into_iter()) { - *dst += src; - } - }); + util::thread::spawn_all(|| { + let mut l_scores = vec![0.0; self.enabled.len()]; + for ballot in util::mutex::chunk_iter(&ballot_chunks) { + ballot.count_primary(&mut l_scores, &self.enabled); + } + for (dst, src) in scores.lock().unwrap().iter_mut().zip(l_scores.into_iter()) { + *dst += src; } }); @@ -94,27 +88,15 @@ impl Counter { }).collect_vec() } fn apply_weights(&mut self, winners: &[ScoreItem]) { - let ballot_chunks = Mutex::new(self.ballots.chunks_mut(1024)); - thread::scope(|s| { - let threads = std::thread::available_parallelism().unwrap().into(); - for _ in 0..threads { - s.spawn(|| { - while let Some(ballots) = { ballot_chunks.lock().unwrap().next() } { - for ballot in ballots { - let index = match ballot.get_primary_index(&self.enabled) { - Some(v) => v, - None => continue, - }; - if let Some(winner) = winners.iter().filter(|&v| v.index == index).next() { - let surplus = winner.value - self.quota; - ballot.apply_weight(surplus / winner.value); - } - } - } - }); + let ballot_chunks = Mutex::new(self.ballots.chunks_mut(CHUNK_SIZE)); + + util::thread::spawn_all(|| { + for ballot in util::mutex::chunk_iter_mut(&ballot_chunks) { + ballot.apply_winners(&self.enabled, winners, self.quota); } }); } + #[inline] pub fn get_quota(&self) -> f64 { self.quota } diff --git a/src/csv/reader.rs b/src/csv/reader.rs index 9669320..6f0707e 100644 --- a/src/csv/reader.rs +++ b/src/csv/reader.rs @@ -1,74 +1,33 @@ -use std::{iter::Peekable, str::Chars}; +use std::str::Split; +pub use row::Row; +pub mod row; -pub struct RowReader<'a> { - it: Peekable>, +pub struct CsvReader<'a> { + it: Split<'a, char>, delimiter: char, - ended: bool, - at: usize, } -impl<'a> RowReader<'a> { - pub fn new(line: &'a str, at: usize, delimiter: char) -> Self { - Self { it: line.chars().peekable(), at, delimiter, ended: false } - } - pub fn get_line_number(&self) -> usize { - self.at - } - pub fn check_empty(&mut self) -> bool { - self.it.peek().is_none() +impl<'a> CsvReader<'a> { + pub fn new(text: &'a str, delimiter: char) -> Self { + Self {it: text.split('\n'), delimiter } } } -impl<'a> Iterator for RowReader<'a> { - type Item = String; +impl<'a> Iterator for CsvReader<'a> { + type Item = Row<'a>; fn next(&mut self) -> Option { - if self.ended { - return None; - } - - let mut value = String::new(); - let mut escaped = false; - let mut end_quote = false; - let can_escape = self.it.peek().copied() == Some('"'); - - if can_escape { - self.it.next(); - } - - for ch in self.it.by_ref() { - if escaped { - value.push(ch); - escaped = false; - continue; + for line in self.it.by_ref() { + let row = Row::new(line, self.delimiter); + if row.has_next() { + return Some(row); } - if can_escape { - if ch == '\\' { - escaped = true; - continue; - } - if ch == '"' { - if end_quote { - value.push(ch); - } - end_quote = !end_quote; - continue; - } - } - if !can_escape || end_quote { - if ch == '\r' { - continue; - } - if ch == self.delimiter { - return Some(value); - } - } - value.push(ch); } - - self.ended = true; - Some(value) + None + } + fn size_hint(&self) -> (usize, Option) { + (0, self.it.size_hint().1) } } diff --git a/src/csv/reader/row.rs b/src/csv/reader/row.rs new file mode 100644 index 0000000..f5fb693 --- /dev/null +++ b/src/csv/reader/row.rs @@ -0,0 +1,85 @@ +use std::{iter::Peekable, str::Chars}; + + +pub struct Row<'a> { + it: Peekable>, + delimiter: char, + ended: bool, +} + +impl<'a> Row<'a> { + pub fn new(line: &'a str, delimiter: char) -> Self { + let mut it = line.chars().peekable(); + let ended = match it.peek().copied() { + Some('\r') => true, + Some(_) => false, + None => true, + }; + Row { it, ended, delimiter } + } +} + +impl<'a> Row<'a> { + pub fn has_next(&self) -> bool { + !self.ended + } +} + +impl<'a> Iterator for Row<'a> { + type Item = String; + + fn next(&mut self) -> Option { + if self.ended { + return None; + } + + let mut value = String::new(); + let mut escaped = false; + let mut end_quote = false; + let can_escape = self.it.peek().copied() == Some('"'); + + if can_escape { + self.it.next(); + } + + for ch in self.it.by_ref() { + if escaped { + value.push(ch); + escaped = false; + continue; + } + if can_escape { + if ch == '\\' { + escaped = true; + continue; + } + if ch == '"' { + if end_quote { + value.push(ch); + } + end_quote = !end_quote; + continue; + } + } + if !can_escape || end_quote { + if ch == '\r' { + continue; + } + if ch == self.delimiter { + return Some(value); + } + } + value.push(ch); + } + + self.ended = true; + Some(value) + } + fn size_hint(&self) -> (usize, Option) { + match self.ended { + true => (0, Some(0)), + false => (1, None), + } + } +} + diff --git a/src/header.rs b/src/header.rs index 7a341c6..b278166 100644 --- a/src/header.rs +++ b/src/header.rs @@ -11,12 +11,12 @@ pub struct Header { } impl Header { - pub fn parse(row: csv::reader::RowReader, winners: usize) -> Result> { + pub fn parse(row: csv::reader::Row, winners: usize) -> Result> { let mut parties = Vec::::new(); let mut candidates = Vec::::new(); - let mut parties_lookup = HashMap::<&str, usize>::new(); + let mut parties_lookup = HashMap::::new(); - for col in row.skip(6).collect_vec().iter() { + for col in row.skip(6) { let [party, name] = col.split(':').next_array().ok_or("Missing ':'")?; if party == "UG" { // independents @@ -33,7 +33,7 @@ impl Header { }); } else { - parties_lookup.insert(party, parties.len()); + parties_lookup.insert(party.to_owned(), parties.len()); parties.push(Party::new(name.to_owned())); } } diff --git a/src/main.rs b/src/main.rs index c934cd2..e1f2bd0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ use std::{env, fs}; use counter::Counter; +use csv::reader::CsvReader; use util::Percent; pub mod util; @@ -22,10 +23,10 @@ fn main() { (args[1].clone(), args[2].parse::().unwrap()) }; - let csv = String::from_utf8(fs::read(&csv_path).unwrap()).unwrap(); - let rows = csv.split('\n').enumerate().map(|(i,v)| csv::reader::RowReader::new(v, i, ',')); + let csv_text = String::from_utf8(fs::read(&csv_path).unwrap()).unwrap(); + let csv = CsvReader::new(&csv_text, ','); - let counter = Counter::new(rows, winner_count).unwrap(); + let counter = Counter::new(csv, winner_count).unwrap(); let mut winners = Vec::with_capacity(winner_count); let total = counter.ballots.len() as f64; let header = counter.header.clone(); diff --git a/src/util.rs b/src/util.rs index e725187..0e5cf7a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -2,6 +2,8 @@ pub mod escape; pub mod percent; pub mod score_item; +pub mod mutex; +pub mod thread; pub use score_item::ScoreItem; pub use escape::{EscapeWriter, EscapeWriterOpts}; diff --git a/src/util/mutex.rs b/src/util/mutex.rs new file mode 100644 index 0000000..0638d4d --- /dev/null +++ b/src/util/mutex.rs @@ -0,0 +1,17 @@ +use std::{slice::{Chunks, ChunksMut}, sync::Mutex}; + + +pub fn chunk_iter<'a,'b,T>(mtx: &'a Mutex>) -> impl Iterator { + std::iter::repeat_with(|| { mtx.lock().unwrap().next() }).take_while(|v| v.is_some()).filter_map(|v| match v { + Some(v) => Some(v.iter()), + None => None, + }).flatten() +} + +pub fn chunk_iter_mut<'a,'b,T>(mtx: &'a Mutex>) -> impl Iterator { + std::iter::repeat_with(|| { mtx.lock().unwrap().next() }).take_while(|v| v.is_some()).filter_map(|v| match v { + Some(v) => Some(v.iter_mut()), + None => None, + }).flatten() +} + diff --git a/src/util/thread.rs b/src/util/thread.rs new file mode 100644 index 0000000..14fcfa8 --- /dev/null +++ b/src/util/thread.rs @@ -0,0 +1,31 @@ +use std::sync::Mutex; + +#[inline] +pub fn spawn_all(f: F) where F: Fn() + Sync { + std::thread::scope(|s| { + for _ in 0..std::thread::available_parallelism().unwrap().into() { + s.spawn(|| f()); + } + }); +} + +#[inline] +pub fn spawn_all_with_result(f: F) -> Result<(), T> + where + F: Fn() -> Result<(), T> + Sync, + T: Send +{ + let res = Mutex::new(None); + + spawn_all(|| { + if let Err(err) = f() { + *res.lock().unwrap() = Some(err); + } + }); + + match res.into_inner().unwrap() { + Some(err) => Err(err), + None => Ok(()), + } +} +