optimisations

This commit is contained in:
Jay Robson 2025-05-12 14:44:22 +10:00
parent beeea3597e
commit c9c106edaf
3 changed files with 74 additions and 52 deletions

View File

@ -12,7 +12,7 @@ pub struct Ballot {
const CANDIDATE_MIN: usize = 6; const CANDIDATE_MIN: usize = 6;
impl Ballot { impl Ballot {
pub fn parse(mut row: csv::reader::Row, header: &Header, stats: &mut Stats) -> Result<Ballot, Box<dyn std::error::Error>> { pub fn parse(mut row: csv::reader::Row, header: &Header, stats: &mut Stats) -> Result<Ballot, Box<dyn std::error::Error + Sync + Send>> {
let mut cols = row.by_ref().skip(6); let mut cols = row.by_ref().skip(6);
let place_filter = |(index, place): (usize, String)| match place.parse::<i32>() { let place_filter = |(index, place): (usize, String)| match place.parse::<i32>() {

View File

@ -1,7 +1,6 @@
use std::{sync::{Arc, Mutex}, thread}; use std::sync::{atomic::{AtomicUsize, Ordering}, Arc, Mutex};
use itertools::Itertools;
use crate::{ballot::Ballot, csv, header::Header, stats::Stats, util::{self, ScoreItem}}; use crate::{ballot::Ballot, csv, header::Header, stats::Stats, util::{self, ScoreItem}};
use itertools::Itertools;
#[derive(Debug)] #[derive(Debug)]
pub enum Event { pub enum Event {
@ -9,7 +8,6 @@ pub enum Event {
Win(Vec<ScoreItem>), Win(Vec<ScoreItem>),
} }
#[derive(Debug)]
pub struct Counter { pub struct Counter {
pub header: Arc<Header>, pub header: Arc<Header>,
pub ballots: Vec<Ballot>, pub ballots: Vec<Ballot>,
@ -24,36 +22,38 @@ const CHUNK_SIZE: usize = 1024;
impl Counter { impl Counter {
pub fn new(mut csv: csv::reader::CsvReader, winners: usize) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> { pub fn new(mut csv: csv::reader::CsvReader, winners: usize) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let header = Arc::new(Header::parse(csv.next().ok_or("csv header missing")?, winners)?); let header = Arc::new(Header::parse(csv.next().ok_or("csv header missing")?, winners)?);
let ballots = Mutex::new(Vec::new());
let stats = Mutex::new(Stats::new());
let csv = Mutex::new(csv); let csv = Mutex::new(csv);
if winners > header.candidates.len() { if winners > header.candidates.len() {
return Err("winners can't be smaller than the candidates list".into()); return Err("winners can't be smaller than the candidates list".into());
} }
util::thread::spawn_all_with_result(|| -> Result<(), Box<dyn std::error::Error + Sync + Send>> { let (ballots, stats) = util::thread::spawn_all_with_callback(|| -> Result<_, Box<dyn std::error::Error + Send + Sync>> {
let mut l_stats = Stats::new(); let mut l_stats = Stats::new();
let mut l_ballots = Vec::new(); let mut l_ballots = Vec::new();
loop { loop {
let rows = csv.lock().unwrap().by_ref().take(CHUNK_SIZE).collect_vec(); let rows = csv.lock().unwrap().by_ref().take(CHUNK_SIZE).collect_vec();
if rows.len() == 0 { if rows.len() == 0 {
break; break;
} }
for row in rows { for row in rows {
l_ballots.push(Ballot::parse(row, &header, &mut l_stats).unwrap()); l_ballots.push(Ballot::parse(row, &header, &mut l_stats)?);
} }
ballots.lock().unwrap().append(&mut l_ballots);
} }
{stats.lock().unwrap().add(&l_stats)}; Ok((l_stats, l_ballots))
Ok(()) }, |it| -> Result<_, Box<dyn std::error::Error + Send + Sync>> {
})?;
let ballots = ballots.into_inner().unwrap(); let mut ballots = Vec::new();
let stats = stats.into_inner().unwrap(); let mut stats = Stats::new();
for value in it {
let (l_stats, mut l_ballots) = value?;
ballots.append(&mut l_ballots);
stats.add(&l_stats);
}
Ok((ballots, stats))
})?;
let enabled = vec![true; header.candidates.len()]; let enabled = vec![true; header.candidates.len()];
let quota = (ballots.len() as f64) / (header.winners as f64 + 1.0) + 1.0; let quota = (ballots.len() as f64) / (header.winners as f64 + 1.0) + 1.0;
@ -68,21 +68,23 @@ impl Counter {
quota, quota,
}) })
} }
fn count_primaries(&self) -> Vec<ScoreItem> { fn count_primaries(&mut self) -> Vec<ScoreItem> {
let mut scores = Mutex::new(vec![0.0; self.enabled.len()]); let mut scores = vec![0.0; self.enabled.len()];
let ballot_chunks = Mutex::new(self.ballots.chunks(CHUNK_SIZE)); let ballots_at = AtomicUsize::new(0);
util::thread::spawn_all(|| { util::thread::spawn_all_with_callback(|| {
let mut l_scores = vec![0.0; self.enabled.len()]; let mut l_scores = vec![0.0; self.enabled.len()];
for ballot in util::mutex::chunk_iter(&ballot_chunks) { while let Some(ballots) = { self.ballots.get(ballots_at.fetch_add(CHUNK_SIZE, Ordering::Relaxed)..) } {
ballot.count_primary(&mut l_scores, &self.enabled); for ballot in ballots.iter().take(CHUNK_SIZE) {
} ballot.count_primary(&mut l_scores, &self.enabled);
for (dst, src) in scores.lock().unwrap().iter_mut().zip(l_scores.into_iter()) { }
*dst += src;
} }
l_scores
}, |it| for l_scores in it {
scores.iter_mut().zip(l_scores).for_each(|(dst, src)| *dst += src);
}); });
scores.get_mut().unwrap().iter().copied().enumerate().filter_map(|(index, value)| match self.enabled[index] { scores.into_iter().enumerate().filter_map(|(index, value)| match self.enabled[index] {
true => Some(ScoreItem::new(index, value)), true => Some(ScoreItem::new(index, value)),
false => None, false => None,
}).collect_vec() }).collect_vec()

View File

@ -1,31 +1,51 @@
use std::sync::Mutex; use std::sync::mpsc::{self, Receiver};
#[inline] pub struct Iter<T> {
pub fn spawn_all<F>(f: F) where F: Fn() + Sync { rx: Receiver<T>,
std::thread::scope(|s| { left: usize,
for _ in 0..std::thread::available_parallelism().unwrap().into() {
s.spawn(|| f());
}
});
} }
#[inline] impl<T> Iterator for Iter<T> {
pub fn spawn_all_with_result<F, T>(f: F) -> Result<(), T> type Item = T;
where fn next(&mut self) -> Option<Self::Item> {
F: Fn() -> Result<(), T> + Sync, if self.left > 0 {
T: Send self.left -= 1;
{ Some(self.rx.recv().unwrap())
let res = Mutex::new(None); } else {
None
spawn_all(|| {
if let Err(err) = f() {
*res.lock().unwrap() = Some(err);
} }
}); }
fn size_hint(&self) -> (usize, Option<usize>) {
match res.into_inner().unwrap() { (self.left, Some(self.left))
Some(err) => Err(err),
None => Ok(()),
} }
} }
pub fn thread_count() -> usize {
std::thread::available_parallelism().unwrap().into()
}
pub fn spawn_all(func: impl Fn() + Sync) {
let num_threads = std::thread::available_parallelism().unwrap().into();
std::thread::scope(|s| {
for _ in 0..num_threads {
s.spawn(|| func());
}
});
}
pub fn spawn_all_with_callback<T: Send, R>(func: impl Fn() -> T + Sync, callback: impl FnOnce(Iter<T>) -> R) -> R {
let (tx, rx) = mpsc::channel();
let num_threads = std::thread::available_parallelism().unwrap().into();
std::thread::scope(|s| {
for _ in 0..num_threads {
s.spawn(|| tx.send(func()).unwrap());
}
callback(Iter {
left: num_threads,
rx,
})
})
}