From c9c106edaffb9c08b19f697762d1e36596f9f9b6 Mon Sep 17 00:00:00 2001 From: Jay Robson Date: Mon, 12 May 2025 14:44:22 +1000 Subject: [PATCH] optimisations --- src/ballot.rs | 2 +- src/counter.rs | 56 ++++++++++++++++++++------------------ src/util/thread.rs | 68 ++++++++++++++++++++++++++++++---------------- 3 files changed, 74 insertions(+), 52 deletions(-) diff --git a/src/ballot.rs b/src/ballot.rs index e58f7bf..7dbd02a 100644 --- a/src/ballot.rs +++ b/src/ballot.rs @@ -12,7 +12,7 @@ pub struct Ballot { const CANDIDATE_MIN: usize = 6; impl Ballot { - pub fn parse(mut row: csv::reader::Row, header: &Header, stats: &mut Stats) -> Result> { + pub fn parse(mut row: csv::reader::Row, header: &Header, stats: &mut Stats) -> Result> { let mut cols = row.by_ref().skip(6); let place_filter = |(index, place): (usize, String)| match place.parse::() { diff --git a/src/counter.rs b/src/counter.rs index 729a12c..d8932ba 100644 --- a/src/counter.rs +++ b/src/counter.rs @@ -1,7 +1,6 @@ -use std::{sync::{Arc, Mutex}, thread}; -use itertools::Itertools; - +use std::sync::{atomic::{AtomicUsize, Ordering}, Arc, Mutex}; use crate::{ballot::Ballot, csv, header::Header, stats::Stats, util::{self, ScoreItem}}; +use itertools::Itertools; #[derive(Debug)] pub enum Event { @@ -9,7 +8,6 @@ pub enum Event { Win(Vec), } -#[derive(Debug)] pub struct Counter { pub header: Arc
, pub ballots: Vec, @@ -24,36 +22,38 @@ const CHUNK_SIZE: usize = 1024; impl Counter { pub fn new(mut csv: csv::reader::CsvReader, winners: usize) -> Result> { let header = Arc::new(Header::parse(csv.next().ok_or("csv header missing")?, winners)?); - let ballots = Mutex::new(Vec::new()); - let stats = Mutex::new(Stats::new()); let csv = Mutex::new(csv); if winners > header.candidates.len() { return Err("winners can't be smaller than the candidates list".into()); } - util::thread::spawn_all_with_result(|| -> Result<(), Box> { + let (ballots, stats) = util::thread::spawn_all_with_callback(|| -> Result<_, Box> { let mut l_stats = Stats::new(); let mut l_ballots = Vec::new(); loop { let rows = csv.lock().unwrap().by_ref().take(CHUNK_SIZE).collect_vec(); - if rows.len() == 0 { break; } - for row in rows { - l_ballots.push(Ballot::parse(row, &header, &mut l_stats).unwrap()); + l_ballots.push(Ballot::parse(row, &header, &mut l_stats)?); } - - ballots.lock().unwrap().append(&mut l_ballots); } - {stats.lock().unwrap().add(&l_stats)}; - Ok(()) - })?; + Ok((l_stats, l_ballots)) + }, |it| -> Result<_, Box> { - let ballots = ballots.into_inner().unwrap(); - let stats = stats.into_inner().unwrap(); + let mut ballots = Vec::new(); + let mut stats = Stats::new(); + + for value in it { + let (l_stats, mut l_ballots) = value?; + ballots.append(&mut l_ballots); + stats.add(&l_stats); + } + + Ok((ballots, stats)) + })?; let enabled = vec![true; header.candidates.len()]; let quota = (ballots.len() as f64) / (header.winners as f64 + 1.0) + 1.0; @@ -68,21 +68,23 @@ impl Counter { quota, }) } - fn count_primaries(&self) -> Vec { - let mut scores = Mutex::new(vec![0.0; self.enabled.len()]); - let ballot_chunks = Mutex::new(self.ballots.chunks(CHUNK_SIZE)); + fn count_primaries(&mut self) -> Vec { + let mut scores = vec![0.0; self.enabled.len()]; + let ballots_at = AtomicUsize::new(0); - util::thread::spawn_all(|| { + util::thread::spawn_all_with_callback(|| { let mut l_scores = vec![0.0; self.enabled.len()]; - for ballot in util::mutex::chunk_iter(&ballot_chunks) { - ballot.count_primary(&mut l_scores, &self.enabled); - } - for (dst, src) in scores.lock().unwrap().iter_mut().zip(l_scores.into_iter()) { - *dst += src; + while let Some(ballots) = { self.ballots.get(ballots_at.fetch_add(CHUNK_SIZE, Ordering::Relaxed)..) } { + for ballot in ballots.iter().take(CHUNK_SIZE) { + ballot.count_primary(&mut l_scores, &self.enabled); + } } + l_scores + }, |it| for l_scores in it { + scores.iter_mut().zip(l_scores).for_each(|(dst, src)| *dst += src); }); - scores.get_mut().unwrap().iter().copied().enumerate().filter_map(|(index, value)| match self.enabled[index] { + scores.into_iter().enumerate().filter_map(|(index, value)| match self.enabled[index] { true => Some(ScoreItem::new(index, value)), false => None, }).collect_vec() diff --git a/src/util/thread.rs b/src/util/thread.rs index 14fcfa8..6db4d7c 100644 --- a/src/util/thread.rs +++ b/src/util/thread.rs @@ -1,31 +1,51 @@ -use std::sync::Mutex; +use std::sync::mpsc::{self, Receiver}; -#[inline] -pub fn spawn_all(f: F) where F: Fn() + Sync { - std::thread::scope(|s| { - for _ in 0..std::thread::available_parallelism().unwrap().into() { - s.spawn(|| f()); - } - }); +pub struct Iter { + rx: Receiver, + left: usize, } -#[inline] -pub fn spawn_all_with_result(f: F) -> Result<(), T> - where - F: Fn() -> Result<(), T> + Sync, - T: Send -{ - let res = Mutex::new(None); - - spawn_all(|| { - if let Err(err) = f() { - *res.lock().unwrap() = Some(err); +impl Iterator for Iter { + type Item = T; + fn next(&mut self) -> Option { + if self.left > 0 { + self.left -= 1; + Some(self.rx.recv().unwrap()) + } else { + None } - }); - - match res.into_inner().unwrap() { - Some(err) => Err(err), - None => Ok(()), + } + fn size_hint(&self) -> (usize, Option) { + (self.left, Some(self.left)) } } +pub fn thread_count() -> usize { + std::thread::available_parallelism().unwrap().into() +} + +pub fn spawn_all(func: impl Fn() + Sync) { + let num_threads = std::thread::available_parallelism().unwrap().into(); + + std::thread::scope(|s| { + for _ in 0..num_threads { + s.spawn(|| func()); + } + }); +} + +pub fn spawn_all_with_callback(func: impl Fn() -> T + Sync, callback: impl FnOnce(Iter) -> R) -> R { + let (tx, rx) = mpsc::channel(); + let num_threads = std::thread::available_parallelism().unwrap().into(); + + std::thread::scope(|s| { + for _ in 0..num_threads { + s.spawn(|| tx.send(func()).unwrap()); + } + callback(Iter { + left: num_threads, + rx, + }) + }) +} +