[Rust/2022/16] Rewrite and improve solution

This commit is contained in:
Felix Bargfeldt 2023-10-22 17:59:15 +02:00
parent 8d03457bfb
commit 00682222e0
Signed by: Defelo
GPG key ID: 2A05272471204DD3
2 changed files with 360 additions and 116 deletions

View file

@ -1,146 +1,317 @@
#![feature(test)]
use regex::Regex;
use rustc_hash::FxHashMap;
use std::{
cmp::{Ordering, Reverse},
collections::BinaryHeap,
};
type Input = Vec<Valve>;
use aoc::iter_ext::IterExt;
use itertools::Itertools;
use rustc_hash::{FxHashMap, FxHashSet};
#[derive(Debug)]
struct Input {
valves: Vec<Valve>,
start: usize,
}
#[derive(Debug)]
struct Valve {
rate: u32,
tunnels: Vec<usize>,
}
fn setup(input: &str) -> Input {
let regex = Regex::new(r"^Valve ([A-Z]+) .* rate=(\d+); .* valves? ([A-Z, ]+)$").unwrap();
let mut names = FxHashMap::default();
names.insert("AA", 0);
let mut name = |n| {
let cnt = names.len();
*names.entry(n).or_insert(cnt)
};
let mut valves: FxHashMap<usize, Valve> = input
.trim()
let names = input
.lines()
.enumerate()
.map(|(i, line)| (line.split_whitespace().nth(1).unwrap(), i))
.collect::<FxHashMap<_, _>>();
let valves = input
.lines()
.map(|line| {
let caps = regex.captures(line).unwrap();
let v = name(caps.get(1).unwrap().as_str());
let rate = caps[2].parse().unwrap();
let tunnels = caps
.get(3)
let rate = line
.split('=')
.nth(1)
.unwrap()
.as_str()
.split(", ")
.map(&mut name)
.split(';')
.next()
.unwrap()
.parse()
.unwrap();
let tunnels = line
.split_whitespace()
.skip(9)
.map(|v| *names.get(v.trim_matches(',')).unwrap())
.collect();
(v, Valve { rate, tunnels })
Valve { rate, tunnels }
})
.collect();
(0..valves.len())
.map(|x| valves.remove(&x).unwrap())
.collect()
Input {
valves,
start: names["AA"],
}
}
struct Solver<'a> {
valves: &'a Input,
dist: Vec<Vec<u32>>,
dp: FxHashMap<(usize, u32, u64), u32>,
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct State {
human: SubState,
elephant: SubState,
released: u32,
valves: ValveState,
}
impl<'a> Solver<'a> {
fn new(valves: &'a Input) -> Self {
let mut dist = vec![vec![u32::MAX; valves.len()]; valves.len()];
for (i, v) in valves.iter().enumerate() {
dist[i][i] = 0;
v.tunnels.iter().for_each(|&j| dist[i][j] = 1);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct SubState {
time: u32,
position: usize,
}
impl PartialOrd for State {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for State {
fn cmp(&self, other: &Self) -> Ordering {
self.released.cmp(&other.released)
}
}
impl SubState {
fn next(
self,
graph: &Graph,
valves: ValveState,
) -> impl Iterator<Item = (SubState, u32, ValveState)> + '_ {
graph.closed_valves(valves).filter_map(move |(i, rate)| {
let time = self
.time
.checked_sub(graph.distances[self.position][i] + 2)?
+ 1;
Some((
SubState { time, position: i },
rate * time,
valves.open_valve(i),
))
})
}
}
impl State {
fn init(position: usize, human_time: u32, elephant_time: u32) -> Self {
Self {
human: SubState {
time: human_time,
position,
},
elephant: SubState {
time: elephant_time,
position,
},
released: 0,
valves: ValveState::default(),
}
for k in 0..valves.len() {
for i in 0..valves.len() {
for j in 0..valves.len() {
if let Some(d) = dist[i][k].checked_add(dist[k][j]) {
dist[i][j] = dist[i][j].min(d);
}
}
fn update(
self,
human: Option<SubState>,
elephant: Option<SubState>,
released: u32,
valves: ValveState,
) -> Self {
Self {
human: human.unwrap_or(self.human),
elephant: elephant.unwrap_or(self.elephant),
released: self.released + released,
valves,
}
}
fn next(self, graph: &Graph) -> impl Iterator<Item = State> + '_ {
self.elephant
.next(graph, self.valves)
.flat_map(move |(new_elephant, elephant_released, valves)| {
self.human
.next(graph, valves)
.map(move |(new_human, human_released, valves)| {
self.update(
Some(new_human),
Some(new_elephant),
human_released + elephant_released,
valves,
)
})
})
.chain_if_empty(self.human.next(graph, self.valves).map(
move |(new_human, released, valves)| {
self.update(Some(new_human), None, released, valves)
},
))
.chain_if_empty(self.elephant.next(graph, self.valves).map(
move |(new_elephant, released, valves)| {
self.update(None, Some(new_elephant), released, valves)
},
))
}
}
#[repr(transparent)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
struct ValveState(u64);
impl ValveState {
fn is_valve_open(self, valve: usize) -> bool {
self.0 & (1 << valve) != 0
}
fn open_valve(self, valve: usize) -> Self {
Self(self.0 | (1 << valve))
}
}
struct Graph {
rates: Vec<u32>,
distances: Vec<Vec<u32>>,
start: usize,
}
impl Graph {
fn from_input(input: &Input) -> Self {
let distances = floyd_warshall(&input.valves);
let rates = input.valves.iter().map(|v| v.rate).collect();
Self {
rates,
distances,
start: input.start,
}
}
fn remove_irrelevant_nodes(mut self) -> Self {
let relevant_valves = self
.rates
.iter()
.enumerate()
.filter(|&(i, &rate)| rate > 0 || i == self.start)
.map(|(i, _)| i)
.collect_vec();
self.distances = relevant_valves
.iter()
.map(|&i| {
relevant_valves
.iter()
.map(|&j| self.distances[i][j])
.collect()
})
.collect();
self.start = relevant_valves
.iter()
.find_position(|&&i| i == self.start)
.unwrap()
.0;
self.rates = relevant_valves.into_iter().map(|i| self.rates[i]).collect();
self
}
fn sort_valves(mut self) -> Self {
let indices = (0..self.rates.len())
.sorted_unstable_by_key(|&i| Reverse(self.rates[i]))
.collect_vec();
self.distances = indices
.iter()
.map(|&i| indices.iter().map(|&j| self.distances[i][j]).collect())
.collect();
self.start = indices
.iter()
.find_position(|&&i| i == self.start)
.unwrap()
.0;
self.rates = indices.into_iter().map(|i| self.rates[i]).collect();
self
}
fn closed_valves(&self, valves: ValveState) -> impl Iterator<Item = (usize, u32)> + '_ {
self.rates
.iter()
.copied()
.enumerate()
.filter(move |&(i, rate)| !valves.is_valve_open(i) && rate > 0)
}
}
fn floyd_warshall(valves: &[Valve]) -> Vec<Vec<u32>> {
let mut dist = vec![vec![u32::MAX; valves.len()]; valves.len()];
for (i, valve) in valves.iter().enumerate() {
for &j in &valve.tunnels {
dist[i][j] = 1;
}
dist[i][i] = 0;
}
for k in 0..valves.len() {
for i in 0..valves.len() {
let di = dist[i][k];
if di == u32::MAX {
continue;
}
for j in 0..valves.len() {
let dj = dist[k][j];
if dj == u32::MAX {
continue;
}
dist[i][j] = dist[i][j].min(di + dj);
}
}
}
dist
}
Self {
valves,
dist,
dp: FxHashMap::default(),
fn solve(input: &Input, human_time: u32, elephant_time: u32) -> u32 {
let graph = Graph::from_input(input)
.remove_irrelevant_nodes()
.sort_valves();
let mut queue = BinaryHeap::from([State::init(graph.start, human_time, elephant_time)]);
let mut seen = FxHashSet::default();
let mut out = 0;
while let Some(state) = queue.pop() {
out = out.max(state.released);
for next in state.next(&graph) {
let closest = graph
.closed_valves(next.valves)
.map(|(i, _)| graph.distances[next.human.position][i])
.min()
.unwrap_or(0);
let time = (next.human.time + 1).saturating_sub(closest);
let max = next.released
+ graph
.closed_valves(next.valves)
.take(time as usize / 2)
.enumerate()
.map(|(i, (_, v))| v * (time - ((i as u32 + 1) * 2)))
.sum::<u32>();
if max < out {
continue;
}
if seen.insert(next) {
queue.push(next);
}
}
}
fn solve(&mut self, p: usize, time: u32, closed: u64) -> u32 {
let key = (p, time, closed);
if let Some(&result) = self.dp.get(&key) {
return result;
}
let result = (0..self.valves.len())
.filter_map(|q| {
if closed & 1 << q == 0 {
return None;
}
if let Some(t) = time.checked_sub(self.dist[p][q] + 1) {
Some(self.solve(q, t, closed & !(1 << q)) + self.valves[q].rate * t)
} else {
None
}
})
.max()
.unwrap_or(0);
self.dp.insert(key, result);
result
}
out
}
fn part1(input: &Input) -> u32 {
Solver::new(input).solve(
0,
30,
input
.iter()
.enumerate()
.filter(|(_, x)| x.rate > 0)
.fold(0, |acc, (i, _)| acc | 1 << i),
)
solve(input, 30, 0)
}
fn part2(input: &Input) -> u32 {
let mut solver = Solver::new(input);
let valves = input
.iter()
.enumerate()
.filter(|(_, x)| x.rate > 0)
.map(|(i, _)| i)
.collect::<Vec<_>>();
(0u64..1 << (valves.len() - 1))
.filter(|&s| valves.len().abs_diff(s.count_ones() as usize * 2) <= 1)
.map(|s| {
let a = solver.solve(
0,
26,
valves
.iter()
.enumerate()
.filter(|&(i, _)| s & 1 << i != 0)
.fold(0, |acc, (_, j)| acc | 1 << j),
);
let b = solver.solve(
0,
26,
valves
.iter()
.enumerate()
.filter(|&(i, _)| s & 1 << i == 0)
.fold(0, |acc, (_, j)| acc | 1 << j),
);
a + b
})
.max()
.unwrap()
solve(input, 26, 26)
}
aoc::main!(2022, 16, ex: 1);

View file

@ -1,13 +1,16 @@
pub trait IterExt<I>
where
I: Iterator,
{
fn take_while_inclusive<P>(self, predicate: P) -> TakeWhileInclusive<I, P>
pub trait IterExt: Iterator {
fn take_while_inclusive<P>(self, predicate: P) -> TakeWhileInclusive<Self, P>
where
P: FnMut(&I::Item) -> bool;
Self: Sized,
P: FnMut(&Self::Item) -> bool;
fn chain_if_empty<U>(self, other: U) -> ChainIfEmpty<Self, U::IntoIter>
where
Self: Sized,
U: IntoIterator<Item = Self::Item>;
}
impl<I> IterExt<I> for I
impl<I> IterExt for I
where
I: Iterator,
{
@ -17,6 +20,17 @@ where
{
TakeWhileInclusive::new(self, predicate)
}
fn chain_if_empty<U>(self, other: U) -> ChainIfEmpty<I, U::IntoIter>
where
U: IntoIterator<Item = Self::Item>,
{
ChainIfEmpty {
iter: self,
other: other.into_iter(),
state: State::Unknown,
}
}
}
pub struct TakeWhileInclusive<I, P> {
@ -55,6 +69,43 @@ where
}
}
enum State {
Empty,
NotEmpty,
Unknown,
}
pub struct ChainIfEmpty<I, U> {
iter: I,
other: U,
state: State,
}
impl<I, U> Iterator for ChainIfEmpty<I, U>
where
I: Iterator,
U: Iterator<Item = I::Item>,
{
type Item = I::Item;
fn next(&mut self) -> Option<Self::Item> {
match self.state {
State::Empty => self.other.next(),
State::NotEmpty => self.iter.next(),
State::Unknown => match self.iter.next() {
Some(x) => {
self.state = State::NotEmpty;
Some(x)
}
None => {
self.state = State::Empty;
self.other.next()
}
},
}
}
}
#[cfg(test)]
mod tests_take_while_inclusive {
use super::*;
@ -79,3 +130,25 @@ mod tests_take_while_inclusive {
test!(always_true, [1, 2, 3, 4], |_| true, vec![1, 2, 3, 4]);
test!(always_false, [1, 2, 3, 4], |_| false, vec![1]);
}
#[cfg(test)]
mod tests_chain_if_empty {
use super::*;
macro_rules! test {
($name:ident, $inp1:expr, $inp2:expr, $exp: expr) => {
#[test]
fn $name() {
assert_eq!(
$inp1.into_iter().chain_if_empty($inp2).collect::<Vec<_>>(),
$exp
);
}
};
}
test!(both_empty, Vec::<i32>::new(), [], vec![]);
test!(first_empty, [], [4, 5, 6], vec![4, 5, 6]);
test!(second_empty, [1, 2, 3], [], vec![1, 2, 3]);
test!(both_nonempty, [1, 2, 3], [4, 5, 6], vec![1, 2, 3]);
}