diff --git a/src/ballot.rs b/src/ballot.rs index 1c8b3a9..3001ce5 100644 --- a/src/ballot.rs +++ b/src/ballot.rs @@ -57,11 +57,9 @@ impl Ballot { pub fn apply_weight(&mut self, weight: f64) { self.weight *= weight; } - pub fn get_primary_index(&self, enabled: F) -> Option - where F: Fn(usize) -> bool - { + pub fn get_primary_index(&self, enabled: &[bool]) -> Option { for &id in self.votes.iter() { - if enabled(id) { + if enabled[id] { return Some(id); } } diff --git a/src/counter.rs b/src/counter.rs index 7c55396..7d06779 100644 --- a/src/counter.rs +++ b/src/counter.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::{sync::{atomic::{AtomicUsize, Ordering}, mpsc, Arc, Mutex}, thread}; use itertools::Itertools; use crate::{ballot::Ballot, header::Header, stats::Stats, util::ScoreItem}; @@ -11,7 +11,7 @@ pub enum Event { #[derive(Debug)] pub struct Counter { - pub header: Rc
, + pub header: Arc
, pub ballots: Vec, pub stats: Stats, winners_left: usize, @@ -21,7 +21,7 @@ pub struct Counter { impl Counter { pub fn new(mut csv: quick_csv::Csv, winners: usize) -> Result> { - let header = Rc::new(Header::parse(csv.next().ok_or("csv header missing")??, winners)?); + let header = Arc::new(Header::parse(csv.next().ok_or("csv header missing")??, winners)?); let mut ballots = Vec::new(); let mut stats = Stats::new(); @@ -47,17 +47,55 @@ impl Counter { }) } fn count_primaries(&self) -> Vec { - let mut scores = vec![0.0; self.enabled.len()]; - for ballot in self.ballots.iter() { - if let Some(index) = ballot.get_primary_index(|id| self.enabled[id]) { - scores[index] += ballot.get_weight(); + let mut scores = Mutex::new(vec![0.0; self.enabled.len()]); + let ballot_chunks = Mutex::new(self.ballots.chunks(1024)); + + thread::scope(|s| { + let threads = std::thread::available_parallelism().unwrap().into(); + for _ in 0..threads { + s.spawn(|| { + let mut l_scores = vec![0.0; self.enabled.len()]; + while let Some(ballots) = { ballot_chunks.lock().unwrap().next() } { + for ballot in ballots { + if let Some(index) = ballot.get_primary_index(&self.enabled) { + l_scores[index] += ballot.get_weight(); + } + } + } + for (dst, src) in scores.lock().unwrap().iter_mut().zip(l_scores.into_iter()) { + *dst += src; + } + }); } - } - scores.into_iter().enumerate().filter_map(|(index, value)| match self.enabled[index] { + }); + + scores.get_mut().unwrap().iter().copied().enumerate().filter_map(|(index, value)| match self.enabled[index] { true => Some(ScoreItem::new(index, value)), false => None, }).collect_vec() } + fn apply_weights(&mut self, winners: &[ScoreItem]) { + let ballot_chunks = Mutex::new(self.ballots.chunks_mut(1024)); + thread::scope(|s| { + let threads = std::thread::available_parallelism().unwrap().into(); + for _ in 0..threads { + s.spawn(|| { + while let Some(ballots) = { ballot_chunks.lock().unwrap().next() } { + for ballot in ballots { + let index = match ballot.get_primary_index(&self.enabled) { + Some(v) => v, + None => continue, + }; + if let Some(winner) = winners.iter().filter(|&v| v.index == index).next() { + let surplus = winner.value - self.quota; + ballot.apply_weight(surplus / winner.value); + } + } + } + }); + } + }); + } pub fn get_quota(&self) -> f64 { self.quota } @@ -89,19 +127,9 @@ impl Iterator for Counter { } winners.sort_by(|a,b| b.cmp(a)); - - for ballot in self.ballots.iter_mut() { - let index = match ballot.get_primary_index(|id| self.enabled[id]) { - Some(v) => v, - None => continue, - }; - if let Some(winner) = winners.iter().filter(|&v| v.index == index).next() { - let surplus = winner.value - self.quota; - ballot.apply_weight(surplus / winner.value); - } - } - + self.apply_weights(&winners); self.winners_left -= winners.len(); + for winner in winners.iter() { self.enabled[winner.index] = false; }