added multiprocessing

This commit is contained in:
Jay Robson 2025-05-09 22:05:13 +10:00
parent 9d8d601c6a
commit 2256ea367c
2 changed files with 51 additions and 25 deletions

View File

@ -57,11 +57,9 @@ impl Ballot {
pub fn apply_weight(&mut self, weight: f64) { pub fn apply_weight(&mut self, weight: f64) {
self.weight *= weight; self.weight *= weight;
} }
pub fn get_primary_index<F>(&self, enabled: F) -> Option<usize> pub fn get_primary_index(&self, enabled: &[bool]) -> Option<usize> {
where F: Fn(usize) -> bool
{
for &id in self.votes.iter() { for &id in self.votes.iter() {
if enabled(id) { if enabled[id] {
return Some(id); return Some(id);
} }
} }

View File

@ -1,4 +1,4 @@
use std::rc::Rc; use std::{sync::{atomic::{AtomicUsize, Ordering}, mpsc, Arc, Mutex}, thread};
use itertools::Itertools; use itertools::Itertools;
use crate::{ballot::Ballot, header::Header, stats::Stats, util::ScoreItem}; use crate::{ballot::Ballot, header::Header, stats::Stats, util::ScoreItem};
@ -11,7 +11,7 @@ pub enum Event {
#[derive(Debug)] #[derive(Debug)]
pub struct Counter { pub struct Counter {
pub header: Rc<Header>, pub header: Arc<Header>,
pub ballots: Vec<Ballot>, pub ballots: Vec<Ballot>,
pub stats: Stats, pub stats: Stats,
winners_left: usize, winners_left: usize,
@ -21,7 +21,7 @@ pub struct Counter {
impl Counter { impl Counter {
pub fn new<T: std::io::BufRead>(mut csv: quick_csv::Csv<T>, winners: usize) -> Result<Self, Box<dyn std::error::Error>> { pub fn new<T: std::io::BufRead>(mut csv: quick_csv::Csv<T>, winners: usize) -> Result<Self, Box<dyn std::error::Error>> {
let header = Rc::new(Header::parse(csv.next().ok_or("csv header missing")??, winners)?); let header = Arc::new(Header::parse(csv.next().ok_or("csv header missing")??, winners)?);
let mut ballots = Vec::new(); let mut ballots = Vec::new();
let mut stats = Stats::new(); let mut stats = Stats::new();
@ -47,17 +47,55 @@ impl Counter {
}) })
} }
fn count_primaries(&self) -> Vec<ScoreItem> { fn count_primaries(&self) -> Vec<ScoreItem> {
let mut scores = vec![0.0; self.enabled.len()]; let mut scores = Mutex::new(vec![0.0; self.enabled.len()]);
for ballot in self.ballots.iter() { let ballot_chunks = Mutex::new(self.ballots.chunks(1024));
if let Some(index) = ballot.get_primary_index(|id| self.enabled[id]) {
scores[index] += ballot.get_weight(); thread::scope(|s| {
let threads = std::thread::available_parallelism().unwrap().into();
for _ in 0..threads {
s.spawn(|| {
let mut l_scores = vec![0.0; self.enabled.len()];
while let Some(ballots) = { ballot_chunks.lock().unwrap().next() } {
for ballot in ballots {
if let Some(index) = ballot.get_primary_index(&self.enabled) {
l_scores[index] += ballot.get_weight();
}
}
}
for (dst, src) in scores.lock().unwrap().iter_mut().zip(l_scores.into_iter()) {
*dst += src;
}
});
} }
} });
scores.into_iter().enumerate().filter_map(|(index, value)| match self.enabled[index] {
scores.get_mut().unwrap().iter().copied().enumerate().filter_map(|(index, value)| match self.enabled[index] {
true => Some(ScoreItem::new(index, value)), true => Some(ScoreItem::new(index, value)),
false => None, false => None,
}).collect_vec() }).collect_vec()
} }
fn apply_weights(&mut self, winners: &[ScoreItem]) {
let ballot_chunks = Mutex::new(self.ballots.chunks_mut(1024));
thread::scope(|s| {
let threads = std::thread::available_parallelism().unwrap().into();
for _ in 0..threads {
s.spawn(|| {
while let Some(ballots) = { ballot_chunks.lock().unwrap().next() } {
for ballot in ballots {
let index = match ballot.get_primary_index(&self.enabled) {
Some(v) => v,
None => continue,
};
if let Some(winner) = winners.iter().filter(|&v| v.index == index).next() {
let surplus = winner.value - self.quota;
ballot.apply_weight(surplus / winner.value);
}
}
}
});
}
});
}
pub fn get_quota(&self) -> f64 { pub fn get_quota(&self) -> f64 {
self.quota self.quota
} }
@ -89,19 +127,9 @@ impl Iterator for Counter {
} }
winners.sort_by(|a,b| b.cmp(a)); winners.sort_by(|a,b| b.cmp(a));
self.apply_weights(&winners);
for ballot in self.ballots.iter_mut() {
let index = match ballot.get_primary_index(|id| self.enabled[id]) {
Some(v) => v,
None => continue,
};
if let Some(winner) = winners.iter().filter(|&v| v.index == index).next() {
let surplus = winner.value - self.quota;
ballot.apply_weight(surplus / winner.value);
}
}
self.winners_left -= winners.len(); self.winners_left -= winners.len();
for winner in winners.iter() { for winner in winners.iter() {
self.enabled[winner.index] = false; self.enabled[winner.index] = false;
} }