optimisations

This commit is contained in:
Jay Robson 2025-05-12 14:44:22 +10:00
parent beeea3597e
commit c9c106edaf
3 changed files with 74 additions and 52 deletions

View File

@ -12,7 +12,7 @@ pub struct Ballot {
const CANDIDATE_MIN: usize = 6;
impl Ballot {
pub fn parse(mut row: csv::reader::Row, header: &Header, stats: &mut Stats) -> Result<Ballot, Box<dyn std::error::Error>> {
pub fn parse(mut row: csv::reader::Row, header: &Header, stats: &mut Stats) -> Result<Ballot, Box<dyn std::error::Error + Sync + Send>> {
let mut cols = row.by_ref().skip(6);
let place_filter = |(index, place): (usize, String)| match place.parse::<i32>() {

View File

@ -1,7 +1,6 @@
use std::{sync::{Arc, Mutex}, thread};
use itertools::Itertools;
use std::sync::{atomic::{AtomicUsize, Ordering}, Arc, Mutex};
use crate::{ballot::Ballot, csv, header::Header, stats::Stats, util::{self, ScoreItem}};
use itertools::Itertools;
#[derive(Debug)]
pub enum Event {
@ -9,7 +8,6 @@ pub enum Event {
Win(Vec<ScoreItem>),
}
#[derive(Debug)]
pub struct Counter {
pub header: Arc<Header>,
pub ballots: Vec<Ballot>,
@ -24,37 +22,39 @@ const CHUNK_SIZE: usize = 1024;
impl Counter {
pub fn new(mut csv: csv::reader::CsvReader, winners: usize) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let header = Arc::new(Header::parse(csv.next().ok_or("csv header missing")?, winners)?);
let ballots = Mutex::new(Vec::new());
let stats = Mutex::new(Stats::new());
let csv = Mutex::new(csv);
if winners > header.candidates.len() {
return Err("winners can't be smaller than the candidates list".into());
}
util::thread::spawn_all_with_result(|| -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let (ballots, stats) = util::thread::spawn_all_with_callback(|| -> Result<_, Box<dyn std::error::Error + Send + Sync>> {
let mut l_stats = Stats::new();
let mut l_ballots = Vec::new();
loop {
let rows = csv.lock().unwrap().by_ref().take(CHUNK_SIZE).collect_vec();
if rows.len() == 0 {
break;
}
for row in rows {
l_ballots.push(Ballot::parse(row, &header, &mut l_stats).unwrap());
l_ballots.push(Ballot::parse(row, &header, &mut l_stats)?);
}
}
Ok((l_stats, l_ballots))
}, |it| -> Result<_, Box<dyn std::error::Error + Send + Sync>> {
let mut ballots = Vec::new();
let mut stats = Stats::new();
for value in it {
let (l_stats, mut l_ballots) = value?;
ballots.append(&mut l_ballots);
stats.add(&l_stats);
}
ballots.lock().unwrap().append(&mut l_ballots);
}
{stats.lock().unwrap().add(&l_stats)};
Ok(())
Ok((ballots, stats))
})?;
let ballots = ballots.into_inner().unwrap();
let stats = stats.into_inner().unwrap();
let enabled = vec![true; header.candidates.len()];
let quota = (ballots.len() as f64) / (header.winners as f64 + 1.0) + 1.0;
let winners_left = header.winners;
@ -68,21 +68,23 @@ impl Counter {
quota,
})
}
fn count_primaries(&self) -> Vec<ScoreItem> {
let mut scores = Mutex::new(vec![0.0; self.enabled.len()]);
let ballot_chunks = Mutex::new(self.ballots.chunks(CHUNK_SIZE));
fn count_primaries(&mut self) -> Vec<ScoreItem> {
let mut scores = vec![0.0; self.enabled.len()];
let ballots_at = AtomicUsize::new(0);
util::thread::spawn_all(|| {
util::thread::spawn_all_with_callback(|| {
let mut l_scores = vec![0.0; self.enabled.len()];
for ballot in util::mutex::chunk_iter(&ballot_chunks) {
while let Some(ballots) = { self.ballots.get(ballots_at.fetch_add(CHUNK_SIZE, Ordering::Relaxed)..) } {
for ballot in ballots.iter().take(CHUNK_SIZE) {
ballot.count_primary(&mut l_scores, &self.enabled);
}
for (dst, src) in scores.lock().unwrap().iter_mut().zip(l_scores.into_iter()) {
*dst += src;
}
l_scores
}, |it| for l_scores in it {
scores.iter_mut().zip(l_scores).for_each(|(dst, src)| *dst += src);
});
scores.get_mut().unwrap().iter().copied().enumerate().filter_map(|(index, value)| match self.enabled[index] {
scores.into_iter().enumerate().filter_map(|(index, value)| match self.enabled[index] {
true => Some(ScoreItem::new(index, value)),
false => None,
}).collect_vec()

View File

@ -1,31 +1,51 @@
use std::sync::Mutex;
use std::sync::mpsc::{self, Receiver};
pub struct Iter<T> {
rx: Receiver<T>,
left: usize,
}
impl<T> Iterator for Iter<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
if self.left > 0 {
self.left -= 1;
Some(self.rx.recv().unwrap())
} else {
None
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.left, Some(self.left))
}
}
pub fn thread_count() -> usize {
std::thread::available_parallelism().unwrap().into()
}
pub fn spawn_all(func: impl Fn() + Sync) {
let num_threads = std::thread::available_parallelism().unwrap().into();
#[inline]
pub fn spawn_all<F>(f: F) where F: Fn() + Sync {
std::thread::scope(|s| {
for _ in 0..std::thread::available_parallelism().unwrap().into() {
s.spawn(|| f());
for _ in 0..num_threads {
s.spawn(|| func());
}
});
}
#[inline]
pub fn spawn_all_with_result<F, T>(f: F) -> Result<(), T>
where
F: Fn() -> Result<(), T> + Sync,
T: Send
{
let res = Mutex::new(None);
pub fn spawn_all_with_callback<T: Send, R>(func: impl Fn() -> T + Sync, callback: impl FnOnce(Iter<T>) -> R) -> R {
let (tx, rx) = mpsc::channel();
let num_threads = std::thread::available_parallelism().unwrap().into();
spawn_all(|| {
if let Err(err) = f() {
*res.lock().unwrap() = Some(err);
}
});
match res.into_inner().unwrap() {
Some(err) => Err(err),
None => Ok(()),
std::thread::scope(|s| {
for _ in 0..num_threads {
s.spawn(|| tx.send(func()).unwrap());
}
callback(Iter {
left: num_threads,
rx,
})
})
}