added optimisations

This commit is contained in:
Jay Robson 2025-05-10 18:05:48 +10:00
parent 17a6c88dbc
commit beeea3597e
9 changed files with 216 additions and 129 deletions

View File

@ -1,6 +1,6 @@
use itertools::Itertools; use itertools::Itertools;
use crate::{csv, header::Header, stats::Stats}; use crate::{csv, header::Header, stats::Stats, util::ScoreItem};
#[derive(Debug)] #[derive(Debug)]
@ -12,7 +12,7 @@ pub struct Ballot {
const CANDIDATE_MIN: usize = 6; const CANDIDATE_MIN: usize = 6;
impl Ballot { impl Ballot {
pub fn parse(mut row: csv::reader::RowReader, header: &Header, stats: &mut Stats) -> Result<Ballot, Box<dyn std::error::Error>> { pub fn parse(mut row: csv::reader::Row, header: &Header, stats: &mut Stats) -> Result<Ballot, Box<dyn std::error::Error>> {
let mut cols = row.by_ref().skip(6); let mut cols = row.by_ref().skip(6);
let place_filter = |(index, place): (usize, String)| match place.parse::<i32>() { let place_filter = |(index, place): (usize, String)| match place.parse::<i32>() {
@ -45,22 +45,32 @@ impl Ballot {
} }
else { else {
println!("abl votes: {votes_party:?}");
println!("btl votes: {votes_candidate:?}");
println!("at: {}", row.get_line_number());
return Err("ballot is informal".into()); return Err("ballot is informal".into());
} }
stats.total += 1; stats.total += 1;
Ok(Ballot { weight: 1.0, votes }) Ok(Ballot { weight: 1.0, votes })
} }
pub fn get_weight(&self) -> f64 { pub fn count_primary(&self, scores: &mut [f64], enabled: &[bool]) {
self.weight if let Some(index) = self.get_primary_index(enabled) {
scores[index] += self.weight;
}
} }
pub fn apply_weight(&mut self, weight: f64) { pub fn apply_winners(&mut self, enabled: &[bool], winners: &[ScoreItem], quota: f64) {
let index = match self.get_primary_index(enabled) {
Some(v) => v,
None => return,
};
if let Some(winner) = winners.iter().filter(|&v| v.index == index).next() {
self.apply_weight((winner.value - quota) / winner.value);
}
}
#[inline]
fn apply_weight(&mut self, weight: f64) {
self.weight *= weight; self.weight *= weight;
} }
pub fn get_primary_index(&self, enabled: &[bool]) -> Option<usize> { #[inline]
fn get_primary_index(&self, enabled: &[bool]) -> Option<usize> {
for &id in self.votes.iter() { for &id in self.votes.iter() {
if enabled[id] { if enabled[id] {
return Some(id); return Some(id);

View File

@ -1,7 +1,7 @@
use std::{sync::{Arc, Mutex}, thread}; use std::{sync::{Arc, Mutex}, thread};
use itertools::Itertools; use itertools::Itertools;
use crate::{ballot::Ballot, csv, header::Header, stats::Stats, util::ScoreItem}; use crate::{ballot::Ballot, csv, header::Header, stats::Stats, util::{self, ScoreItem}};
#[derive(Debug)] #[derive(Debug)]
pub enum Event { pub enum Event {
@ -19,10 +19,10 @@ pub struct Counter {
quota: f64, quota: f64,
} }
const CHUNK_SIZE: usize = 1024;
impl Counter { impl Counter {
pub fn new<'a,I>(mut csv: I, winners: usize) -> Result<Self, Box<dyn std::error::Error>> pub fn new(mut csv: csv::reader::CsvReader, winners: usize) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
where I: Iterator<Item=csv::reader::RowReader<'a>> + Send + Sync
{
let header = Arc::new(Header::parse(csv.next().ok_or("csv header missing")?, winners)?); let header = Arc::new(Header::parse(csv.next().ok_or("csv header missing")?, winners)?);
let ballots = Mutex::new(Vec::new()); let ballots = Mutex::new(Vec::new());
let stats = Mutex::new(Stats::new()); let stats = Mutex::new(Stats::new());
@ -32,22 +32,25 @@ impl Counter {
return Err("winners can't be smaller than the candidates list".into()); return Err("winners can't be smaller than the candidates list".into());
} }
thread::scope(|s| { util::thread::spawn_all_with_result(|| -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let threads = std::thread::available_parallelism().unwrap().into(); let mut l_stats = Stats::new();
for _ in 0..threads { let mut l_ballots = Vec::new();
s.spawn(|| { loop {
let mut l_stats = Stats::new(); let rows = csv.lock().unwrap().by_ref().take(CHUNK_SIZE).collect_vec();
let mut l_ballots = Vec::new();
while let Some(mut row) = { csv.lock().unwrap().next() } { if rows.len() == 0 {
if !row.check_empty() { break;
l_ballots.push(Ballot::parse(row, &header, &mut l_stats).unwrap()); }
}
} for row in rows {
{ballots.lock().unwrap().append(&mut l_ballots)}; l_ballots.push(Ballot::parse(row, &header, &mut l_stats).unwrap());
{stats.lock().unwrap().add(&l_stats)}; }
});
ballots.lock().unwrap().append(&mut l_ballots);
} }
}); {stats.lock().unwrap().add(&l_stats)};
Ok(())
})?;
let ballots = ballots.into_inner().unwrap(); let ballots = ballots.into_inner().unwrap();
let stats = stats.into_inner().unwrap(); let stats = stats.into_inner().unwrap();
@ -67,24 +70,15 @@ impl Counter {
} }
fn count_primaries(&self) -> Vec<ScoreItem> { fn count_primaries(&self) -> Vec<ScoreItem> {
let mut scores = Mutex::new(vec![0.0; self.enabled.len()]); let mut scores = Mutex::new(vec![0.0; self.enabled.len()]);
let ballot_chunks = Mutex::new(self.ballots.chunks(1024)); let ballot_chunks = Mutex::new(self.ballots.chunks(CHUNK_SIZE));
thread::scope(|s| { util::thread::spawn_all(|| {
let threads = std::thread::available_parallelism().unwrap().into(); let mut l_scores = vec![0.0; self.enabled.len()];
for _ in 0..threads { for ballot in util::mutex::chunk_iter(&ballot_chunks) {
s.spawn(|| { ballot.count_primary(&mut l_scores, &self.enabled);
let mut l_scores = vec![0.0; self.enabled.len()]; }
while let Some(ballots) = { ballot_chunks.lock().unwrap().next() } { for (dst, src) in scores.lock().unwrap().iter_mut().zip(l_scores.into_iter()) {
for ballot in ballots { *dst += src;
if let Some(index) = ballot.get_primary_index(&self.enabled) {
l_scores[index] += ballot.get_weight();
}
}
}
for (dst, src) in scores.lock().unwrap().iter_mut().zip(l_scores.into_iter()) {
*dst += src;
}
});
} }
}); });
@ -94,27 +88,15 @@ impl Counter {
}).collect_vec() }).collect_vec()
} }
fn apply_weights(&mut self, winners: &[ScoreItem]) { fn apply_weights(&mut self, winners: &[ScoreItem]) {
let ballot_chunks = Mutex::new(self.ballots.chunks_mut(1024)); let ballot_chunks = Mutex::new(self.ballots.chunks_mut(CHUNK_SIZE));
thread::scope(|s| {
let threads = std::thread::available_parallelism().unwrap().into(); util::thread::spawn_all(|| {
for _ in 0..threads { for ballot in util::mutex::chunk_iter_mut(&ballot_chunks) {
s.spawn(|| { ballot.apply_winners(&self.enabled, winners, self.quota);
while let Some(ballots) = { ballot_chunks.lock().unwrap().next() } {
for ballot in ballots {
let index = match ballot.get_primary_index(&self.enabled) {
Some(v) => v,
None => continue,
};
if let Some(winner) = winners.iter().filter(|&v| v.index == index).next() {
let surplus = winner.value - self.quota;
ballot.apply_weight(surplus / winner.value);
}
}
}
});
} }
}); });
} }
#[inline]
pub fn get_quota(&self) -> f64 { pub fn get_quota(&self) -> f64 {
self.quota self.quota
} }

View File

@ -1,74 +1,33 @@
use std::{iter::Peekable, str::Chars}; use std::str::Split;
pub use row::Row;
pub mod row;
pub struct RowReader<'a> { pub struct CsvReader<'a> {
it: Peekable<Chars<'a>>, it: Split<'a, char>,
delimiter: char, delimiter: char,
ended: bool,
at: usize,
} }
impl<'a> RowReader<'a> { impl<'a> CsvReader<'a> {
pub fn new(line: &'a str, at: usize, delimiter: char) -> Self { pub fn new(text: &'a str, delimiter: char) -> Self {
Self { it: line.chars().peekable(), at, delimiter, ended: false } Self {it: text.split('\n'), delimiter }
}
pub fn get_line_number(&self) -> usize {
self.at
}
pub fn check_empty(&mut self) -> bool {
self.it.peek().is_none()
} }
} }
impl<'a> Iterator for RowReader<'a> { impl<'a> Iterator for CsvReader<'a> {
type Item = String; type Item = Row<'a>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if self.ended { for line in self.it.by_ref() {
return None; let row = Row::new(line, self.delimiter);
} if row.has_next() {
return Some(row);
let mut value = String::new();
let mut escaped = false;
let mut end_quote = false;
let can_escape = self.it.peek().copied() == Some('"');
if can_escape {
self.it.next();
}
for ch in self.it.by_ref() {
if escaped {
value.push(ch);
escaped = false;
continue;
} }
if can_escape {
if ch == '\\' {
escaped = true;
continue;
}
if ch == '"' {
if end_quote {
value.push(ch);
}
end_quote = !end_quote;
continue;
}
}
if !can_escape || end_quote {
if ch == '\r' {
continue;
}
if ch == self.delimiter {
return Some(value);
}
}
value.push(ch);
} }
None
self.ended = true; }
Some(value) fn size_hint(&self) -> (usize, Option<usize>) {
(0, self.it.size_hint().1)
} }
} }

85
src/csv/reader/row.rs Normal file
View File

@ -0,0 +1,85 @@
use std::{iter::Peekable, str::Chars};
/// One line of CSV text, consumed lazily as a sequence of owned field strings.
pub struct Row<'a> {
    /// Remaining characters of the line.
    it: Peekable<Chars<'a>>,
    /// Character that separates fields outside of quotes.
    delimiter: char,
    /// Set once the final field has been handed out (or immediately for a blank line).
    ended: bool,
}

impl<'a> Row<'a> {
    /// Wraps `line` for field-by-field iteration using `delimiter` as the separator.
    ///
    /// A line that is empty, or whose first character is `'\r'`, is treated as
    /// blank: the row starts out already finished and yields no fields.
    pub fn new(line: &'a str, delimiter: char) -> Self {
        let mut it = line.chars().peekable();
        let ended = matches!(it.peek().copied(), None | Some('\r'));
        Row { it, delimiter, ended }
    }

    /// Returns `true` while at least one more field can still be read.
    pub fn has_next(&self) -> bool {
        !self.ended
    }
}

impl<'a> Iterator for Row<'a> {
    type Item = String;

    /// Reads the next field.
    ///
    /// A field that begins with `"` is quoted: inside it a backslash takes the
    /// following character literally, a doubled `""` produces one `"`, and the
    /// delimiter is ordinary text until the closing quote has been seen.
    /// Outside of quotes `'\r'` is dropped and the delimiter ends the field.
    /// Reaching the end of the line ends the final field and finishes the row.
    fn next(&mut self) -> Option<Self::Item> {
        if self.ended {
            return None;
        }

        // Consume the opening quote, if any; its presence switches on quote handling.
        let quoted = self.it.next_if_eq(&'"').is_some();
        let mut field = String::new();
        let mut take_literally = false;
        let mut quote_closed = false;

        for ch in self.it.by_ref() {
            if take_literally {
                field.push(ch);
                take_literally = false;
                continue;
            }

            if quoted {
                match ch {
                    '\\' => {
                        take_literally = true;
                        continue;
                    }
                    '"' => {
                        // A quote right after a closing quote is an escaped quote:
                        // emit it and go back to being inside the quoted text.
                        if quote_closed {
                            field.push(ch);
                        }
                        quote_closed = !quote_closed;
                        continue;
                    }
                    _ => {}
                }
            }

            let outside_quotes = !quoted || quote_closed;
            if outside_quotes {
                if ch == '\r' {
                    continue;
                }
                if ch == self.delimiter {
                    return Some(field);
                }
            }

            field.push(ch);
        }

        // The line ran out: whatever was gathered is the last field.
        self.ended = true;
        Some(field)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.ended {
            (0, Some(0))
        } else {
            // An unfinished row always yields at least its trailing field.
            (1, None)
        }
    }
}

View File

@ -11,12 +11,12 @@ pub struct Header {
} }
impl Header { impl Header {
pub fn parse(row: csv::reader::RowReader, winners: usize) -> Result<Header, Box<dyn std::error::Error>> { pub fn parse(row: csv::reader::Row, winners: usize) -> Result<Header, Box<dyn std::error::Error + Sync + Send>> {
let mut parties = Vec::<Party>::new(); let mut parties = Vec::<Party>::new();
let mut candidates = Vec::<Candidate>::new(); let mut candidates = Vec::<Candidate>::new();
let mut parties_lookup = HashMap::<&str, usize>::new(); let mut parties_lookup = HashMap::<String, usize>::new();
for col in row.skip(6).collect_vec().iter() { for col in row.skip(6) {
let [party, name] = col.split(':').next_array().ok_or("Missing ':'")?; let [party, name] = col.split(':').next_array().ok_or("Missing ':'")?;
if party == "UG" { // independents if party == "UG" { // independents
@ -33,7 +33,7 @@ impl Header {
}); });
} }
else { else {
parties_lookup.insert(party, parties.len()); parties_lookup.insert(party.to_owned(), parties.len());
parties.push(Party::new(name.to_owned())); parties.push(Party::new(name.to_owned()));
} }
} }

View File

@ -1,6 +1,7 @@
use std::{env, fs}; use std::{env, fs};
use counter::Counter; use counter::Counter;
use csv::reader::CsvReader;
use util::Percent; use util::Percent;
pub mod util; pub mod util;
@ -22,10 +23,10 @@ fn main() {
(args[1].clone(), args[2].parse::<usize>().unwrap()) (args[1].clone(), args[2].parse::<usize>().unwrap())
}; };
let csv = String::from_utf8(fs::read(&csv_path).unwrap()).unwrap(); let csv_text = String::from_utf8(fs::read(&csv_path).unwrap()).unwrap();
let rows = csv.split('\n').enumerate().map(|(i,v)| csv::reader::RowReader::new(v, i, ',')); let csv = CsvReader::new(&csv_text, ',');
let counter = Counter::new(rows, winner_count).unwrap(); let counter = Counter::new(csv, winner_count).unwrap();
let mut winners = Vec::with_capacity(winner_count); let mut winners = Vec::with_capacity(winner_count);
let total = counter.ballots.len() as f64; let total = counter.ballots.len() as f64;
let header = counter.header.clone(); let header = counter.header.clone();

View File

@ -2,6 +2,8 @@
pub mod escape; pub mod escape;
pub mod percent; pub mod percent;
pub mod score_item; pub mod score_item;
pub mod mutex;
pub mod thread;
pub use score_item::ScoreItem; pub use score_item::ScoreItem;
pub use escape::{EscapeWriter, EscapeWriterOpts}; pub use escape::{EscapeWriter, EscapeWriterOpts};

17
src/util/mutex.rs Normal file
View File

@ -0,0 +1,17 @@
use std::{slice::{Chunks, ChunksMut}, sync::Mutex};
/// Flattens a mutex-guarded [`Chunks`] iterator into an iterator over its elements.
///
/// Every chunk is pulled from the shared iterator while the lock is held, and the
/// guard is dropped before the chunk's elements are yielded. Several threads may
/// therefore call this on the same mutex: each chunk goes to exactly one caller.
/// Iteration stops the first time the shared iterator reports that it is exhausted.
///
/// # Panics
/// Panics if the mutex is poisoned.
pub fn chunk_iter<'a, 'b, T>(mtx: &'a Mutex<Chunks<'b, T>>) -> impl Iterator<Item = &'b T> + 'a {
    // `move` copies the `&'a Mutex` into the closure (instead of borrowing the local
    // `mtx`), and `+ 'a` names that borrow in the return type, so this does not rely
    // on any edition-specific capture rules.
    std::iter::from_fn(move || mtx.lock().unwrap().next()).flatten()
}
/// Flattens a mutex-guarded [`ChunksMut`] iterator into an iterator over mutable elements.
///
/// Every chunk is pulled from the shared iterator while the lock is held, and the
/// guard is dropped before the chunk's elements are yielded. Several threads may
/// therefore call this on the same mutex: each chunk goes to exactly one caller.
/// Iteration stops the first time the shared iterator reports that it is exhausted.
///
/// # Panics
/// Panics if the mutex is poisoned.
pub fn chunk_iter_mut<'a, 'b, T>(mtx: &'a Mutex<ChunksMut<'b, T>>) -> impl Iterator<Item = &'b mut T> + 'a {
    // `move` copies the `&'a Mutex` into the closure (instead of borrowing the local
    // `mtx`), and `+ 'a` names that borrow in the return type, so this does not rely
    // on any edition-specific capture rules.
    std::iter::from_fn(move || mtx.lock().unwrap().next()).flatten()
}

31
src/util/thread.rs Normal file
View File

@ -0,0 +1,31 @@
use std::sync::Mutex;
/// Runs `f` once on each of N scoped worker threads and returns when all have finished.
///
/// N is the value reported by [`std::thread::available_parallelism`]. If that
/// query fails, a single worker is used rather than panicking.
#[inline]
pub fn spawn_all<F>(f: F) where F: Fn() + Sync {
    let workers = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(1);
    std::thread::scope(|s| {
        for _ in 0..workers {
            s.spawn(|| f());
        }
    });
}
/// Runs `f` on every worker started by `spawn_all` and reports whether any of them failed.
///
/// Returns `Ok(())` when every invocation succeeded. If one or more invocations
/// return `Err`, one of those errors is returned; when several fail, the one
/// stored last wins. A failing worker does not stop the others.
#[inline]
pub fn spawn_all_with_result<F, T>(f: F) -> Result<(), T>
where
    F: Fn() -> Result<(), T> + Sync,
    T: Send,
{
    // Shared slot the workers write their error into.
    let failure: Mutex<Option<T>> = Mutex::new(None);

    spawn_all(|| match f() {
        Ok(()) => {}
        Err(err) => {
            *failure.lock().unwrap() = Some(err);
        }
    });

    failure.into_inner().unwrap().map_or(Ok(()), Err)
}