[Rust/2022/16] Rewrite and improve solution
This commit is contained in:
parent
8d03457bfb
commit
00682222e0
2 changed files with 360 additions and 116 deletions
389
Rust/2022/16.rs
389
Rust/2022/16.rs
|
@ -1,146 +1,317 @@
|
|||
#![feature(test)]
|
||||
|
||||
use regex::Regex;
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::{
|
||||
cmp::{Ordering, Reverse},
|
||||
collections::BinaryHeap,
|
||||
};
|
||||
|
||||
type Input = Vec<Valve>;
|
||||
use aoc::iter_ext::IterExt;
|
||||
use itertools::Itertools;
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Input {
|
||||
valves: Vec<Valve>,
|
||||
start: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Valve {
|
||||
rate: u32,
|
||||
tunnels: Vec<usize>,
|
||||
}
|
||||
|
||||
fn setup(input: &str) -> Input {
|
||||
let regex = Regex::new(r"^Valve ([A-Z]+) .* rate=(\d+); .* valves? ([A-Z, ]+)$").unwrap();
|
||||
let mut names = FxHashMap::default();
|
||||
names.insert("AA", 0);
|
||||
let mut name = |n| {
|
||||
let cnt = names.len();
|
||||
*names.entry(n).or_insert(cnt)
|
||||
};
|
||||
|
||||
let mut valves: FxHashMap<usize, Valve> = input
|
||||
.trim()
|
||||
let names = input
|
||||
.lines()
|
||||
.enumerate()
|
||||
.map(|(i, line)| (line.split_whitespace().nth(1).unwrap(), i))
|
||||
.collect::<FxHashMap<_, _>>();
|
||||
let valves = input
|
||||
.lines()
|
||||
.map(|line| {
|
||||
let caps = regex.captures(line).unwrap();
|
||||
let v = name(caps.get(1).unwrap().as_str());
|
||||
let rate = caps[2].parse().unwrap();
|
||||
let tunnels = caps
|
||||
.get(3)
|
||||
let rate = line
|
||||
.split('=')
|
||||
.nth(1)
|
||||
.unwrap()
|
||||
.as_str()
|
||||
.split(", ")
|
||||
.map(&mut name)
|
||||
.split(';')
|
||||
.next()
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap();
|
||||
let tunnels = line
|
||||
.split_whitespace()
|
||||
.skip(9)
|
||||
.map(|v| *names.get(v.trim_matches(',')).unwrap())
|
||||
.collect();
|
||||
(v, Valve { rate, tunnels })
|
||||
Valve { rate, tunnels }
|
||||
})
|
||||
.collect();
|
||||
|
||||
(0..valves.len())
|
||||
.map(|x| valves.remove(&x).unwrap())
|
||||
.collect()
|
||||
Input {
|
||||
valves,
|
||||
start: names["AA"],
|
||||
}
|
||||
}
|
||||
|
||||
struct Solver<'a> {
|
||||
valves: &'a Input,
|
||||
dist: Vec<Vec<u32>>,
|
||||
dp: FxHashMap<(usize, u32, u64), u32>,
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct State {
|
||||
human: SubState,
|
||||
elephant: SubState,
|
||||
released: u32,
|
||||
valves: ValveState,
|
||||
}
|
||||
|
||||
impl<'a> Solver<'a> {
|
||||
fn new(valves: &'a Input) -> Self {
|
||||
let mut dist = vec![vec![u32::MAX; valves.len()]; valves.len()];
|
||||
for (i, v) in valves.iter().enumerate() {
|
||||
dist[i][i] = 0;
|
||||
v.tunnels.iter().for_each(|&j| dist[i][j] = 1);
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct SubState {
|
||||
time: u32,
|
||||
position: usize,
|
||||
}
|
||||
|
||||
impl PartialOrd for State {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for State {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
self.released.cmp(&other.released)
|
||||
}
|
||||
}
|
||||
|
||||
impl SubState {
|
||||
fn next(
|
||||
self,
|
||||
graph: &Graph,
|
||||
valves: ValveState,
|
||||
) -> impl Iterator<Item = (SubState, u32, ValveState)> + '_ {
|
||||
graph.closed_valves(valves).filter_map(move |(i, rate)| {
|
||||
let time = self
|
||||
.time
|
||||
.checked_sub(graph.distances[self.position][i] + 2)?
|
||||
+ 1;
|
||||
Some((
|
||||
SubState { time, position: i },
|
||||
rate * time,
|
||||
valves.open_valve(i),
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn init(position: usize, human_time: u32, elephant_time: u32) -> Self {
|
||||
Self {
|
||||
human: SubState {
|
||||
time: human_time,
|
||||
position,
|
||||
},
|
||||
elephant: SubState {
|
||||
time: elephant_time,
|
||||
position,
|
||||
},
|
||||
released: 0,
|
||||
valves: ValveState::default(),
|
||||
}
|
||||
for k in 0..valves.len() {
|
||||
for i in 0..valves.len() {
|
||||
for j in 0..valves.len() {
|
||||
if let Some(d) = dist[i][k].checked_add(dist[k][j]) {
|
||||
dist[i][j] = dist[i][j].min(d);
|
||||
}
|
||||
}
|
||||
|
||||
fn update(
|
||||
self,
|
||||
human: Option<SubState>,
|
||||
elephant: Option<SubState>,
|
||||
released: u32,
|
||||
valves: ValveState,
|
||||
) -> Self {
|
||||
Self {
|
||||
human: human.unwrap_or(self.human),
|
||||
elephant: elephant.unwrap_or(self.elephant),
|
||||
released: self.released + released,
|
||||
valves,
|
||||
}
|
||||
}
|
||||
|
||||
fn next(self, graph: &Graph) -> impl Iterator<Item = State> + '_ {
|
||||
self.elephant
|
||||
.next(graph, self.valves)
|
||||
.flat_map(move |(new_elephant, elephant_released, valves)| {
|
||||
self.human
|
||||
.next(graph, valves)
|
||||
.map(move |(new_human, human_released, valves)| {
|
||||
self.update(
|
||||
Some(new_human),
|
||||
Some(new_elephant),
|
||||
human_released + elephant_released,
|
||||
valves,
|
||||
)
|
||||
})
|
||||
})
|
||||
.chain_if_empty(self.human.next(graph, self.valves).map(
|
||||
move |(new_human, released, valves)| {
|
||||
self.update(Some(new_human), None, released, valves)
|
||||
},
|
||||
))
|
||||
.chain_if_empty(self.elephant.next(graph, self.valves).map(
|
||||
move |(new_elephant, released, valves)| {
|
||||
self.update(None, Some(new_elephant), released, valves)
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
|
||||
struct ValveState(u64);
|
||||
|
||||
impl ValveState {
|
||||
fn is_valve_open(self, valve: usize) -> bool {
|
||||
self.0 & (1 << valve) != 0
|
||||
}
|
||||
|
||||
fn open_valve(self, valve: usize) -> Self {
|
||||
Self(self.0 | (1 << valve))
|
||||
}
|
||||
}
|
||||
|
||||
struct Graph {
|
||||
rates: Vec<u32>,
|
||||
distances: Vec<Vec<u32>>,
|
||||
start: usize,
|
||||
}
|
||||
|
||||
impl Graph {
|
||||
fn from_input(input: &Input) -> Self {
|
||||
let distances = floyd_warshall(&input.valves);
|
||||
let rates = input.valves.iter().map(|v| v.rate).collect();
|
||||
Self {
|
||||
rates,
|
||||
distances,
|
||||
start: input.start,
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_irrelevant_nodes(mut self) -> Self {
|
||||
let relevant_valves = self
|
||||
.rates
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|&(i, &rate)| rate > 0 || i == self.start)
|
||||
.map(|(i, _)| i)
|
||||
.collect_vec();
|
||||
self.distances = relevant_valves
|
||||
.iter()
|
||||
.map(|&i| {
|
||||
relevant_valves
|
||||
.iter()
|
||||
.map(|&j| self.distances[i][j])
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
self.start = relevant_valves
|
||||
.iter()
|
||||
.find_position(|&&i| i == self.start)
|
||||
.unwrap()
|
||||
.0;
|
||||
self.rates = relevant_valves.into_iter().map(|i| self.rates[i]).collect();
|
||||
self
|
||||
}
|
||||
|
||||
fn sort_valves(mut self) -> Self {
|
||||
let indices = (0..self.rates.len())
|
||||
.sorted_unstable_by_key(|&i| Reverse(self.rates[i]))
|
||||
.collect_vec();
|
||||
self.distances = indices
|
||||
.iter()
|
||||
.map(|&i| indices.iter().map(|&j| self.distances[i][j]).collect())
|
||||
.collect();
|
||||
self.start = indices
|
||||
.iter()
|
||||
.find_position(|&&i| i == self.start)
|
||||
.unwrap()
|
||||
.0;
|
||||
self.rates = indices.into_iter().map(|i| self.rates[i]).collect();
|
||||
self
|
||||
}
|
||||
|
||||
fn closed_valves(&self, valves: ValveState) -> impl Iterator<Item = (usize, u32)> + '_ {
|
||||
self.rates
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.filter(move |&(i, rate)| !valves.is_valve_open(i) && rate > 0)
|
||||
}
|
||||
}
|
||||
|
||||
fn floyd_warshall(valves: &[Valve]) -> Vec<Vec<u32>> {
|
||||
let mut dist = vec![vec![u32::MAX; valves.len()]; valves.len()];
|
||||
for (i, valve) in valves.iter().enumerate() {
|
||||
for &j in &valve.tunnels {
|
||||
dist[i][j] = 1;
|
||||
}
|
||||
dist[i][i] = 0;
|
||||
}
|
||||
for k in 0..valves.len() {
|
||||
for i in 0..valves.len() {
|
||||
let di = dist[i][k];
|
||||
if di == u32::MAX {
|
||||
continue;
|
||||
}
|
||||
for j in 0..valves.len() {
|
||||
let dj = dist[k][j];
|
||||
if dj == u32::MAX {
|
||||
continue;
|
||||
}
|
||||
dist[i][j] = dist[i][j].min(di + dj);
|
||||
}
|
||||
}
|
||||
}
|
||||
dist
|
||||
}
|
||||
|
||||
Self {
|
||||
valves,
|
||||
dist,
|
||||
dp: FxHashMap::default(),
|
||||
fn solve(input: &Input, human_time: u32, elephant_time: u32) -> u32 {
|
||||
let graph = Graph::from_input(input)
|
||||
.remove_irrelevant_nodes()
|
||||
.sort_valves();
|
||||
|
||||
let mut queue = BinaryHeap::from([State::init(graph.start, human_time, elephant_time)]);
|
||||
let mut seen = FxHashSet::default();
|
||||
let mut out = 0;
|
||||
while let Some(state) = queue.pop() {
|
||||
out = out.max(state.released);
|
||||
|
||||
for next in state.next(&graph) {
|
||||
let closest = graph
|
||||
.closed_valves(next.valves)
|
||||
.map(|(i, _)| graph.distances[next.human.position][i])
|
||||
.min()
|
||||
.unwrap_or(0);
|
||||
let time = (next.human.time + 1).saturating_sub(closest);
|
||||
let max = next.released
|
||||
+ graph
|
||||
.closed_valves(next.valves)
|
||||
.take(time as usize / 2)
|
||||
.enumerate()
|
||||
.map(|(i, (_, v))| v * (time - ((i as u32 + 1) * 2)))
|
||||
.sum::<u32>();
|
||||
|
||||
if max < out {
|
||||
continue;
|
||||
}
|
||||
|
||||
if seen.insert(next) {
|
||||
queue.push(next);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn solve(&mut self, p: usize, time: u32, closed: u64) -> u32 {
|
||||
let key = (p, time, closed);
|
||||
if let Some(&result) = self.dp.get(&key) {
|
||||
return result;
|
||||
}
|
||||
|
||||
let result = (0..self.valves.len())
|
||||
.filter_map(|q| {
|
||||
if closed & 1 << q == 0 {
|
||||
return None;
|
||||
}
|
||||
if let Some(t) = time.checked_sub(self.dist[p][q] + 1) {
|
||||
Some(self.solve(q, t, closed & !(1 << q)) + self.valves[q].rate * t)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
self.dp.insert(key, result);
|
||||
result
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn part1(input: &Input) -> u32 {
|
||||
Solver::new(input).solve(
|
||||
0,
|
||||
30,
|
||||
input
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, x)| x.rate > 0)
|
||||
.fold(0, |acc, (i, _)| acc | 1 << i),
|
||||
)
|
||||
solve(input, 30, 0)
|
||||
}
|
||||
|
||||
fn part2(input: &Input) -> u32 {
|
||||
let mut solver = Solver::new(input);
|
||||
let valves = input
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, x)| x.rate > 0)
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
(0u64..1 << (valves.len() - 1))
|
||||
.filter(|&s| valves.len().abs_diff(s.count_ones() as usize * 2) <= 1)
|
||||
.map(|s| {
|
||||
let a = solver.solve(
|
||||
0,
|
||||
26,
|
||||
valves
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|&(i, _)| s & 1 << i != 0)
|
||||
.fold(0, |acc, (_, j)| acc | 1 << j),
|
||||
);
|
||||
let b = solver.solve(
|
||||
0,
|
||||
26,
|
||||
valves
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|&(i, _)| s & 1 << i == 0)
|
||||
.fold(0, |acc, (_, j)| acc | 1 << j),
|
||||
);
|
||||
a + b
|
||||
})
|
||||
.max()
|
||||
.unwrap()
|
||||
solve(input, 26, 26)
|
||||
}
|
||||
|
||||
aoc::main!(2022, 16, ex: 1);
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
pub trait IterExt<I>
|
||||
where
|
||||
I: Iterator,
|
||||
{
|
||||
fn take_while_inclusive<P>(self, predicate: P) -> TakeWhileInclusive<I, P>
|
||||
pub trait IterExt: Iterator {
|
||||
fn take_while_inclusive<P>(self, predicate: P) -> TakeWhileInclusive<Self, P>
|
||||
where
|
||||
P: FnMut(&I::Item) -> bool;
|
||||
Self: Sized,
|
||||
P: FnMut(&Self::Item) -> bool;
|
||||
|
||||
fn chain_if_empty<U>(self, other: U) -> ChainIfEmpty<Self, U::IntoIter>
|
||||
where
|
||||
Self: Sized,
|
||||
U: IntoIterator<Item = Self::Item>;
|
||||
}
|
||||
|
||||
impl<I> IterExt<I> for I
|
||||
impl<I> IterExt for I
|
||||
where
|
||||
I: Iterator,
|
||||
{
|
||||
|
@ -17,6 +20,17 @@ where
|
|||
{
|
||||
TakeWhileInclusive::new(self, predicate)
|
||||
}
|
||||
|
||||
fn chain_if_empty<U>(self, other: U) -> ChainIfEmpty<I, U::IntoIter>
|
||||
where
|
||||
U: IntoIterator<Item = Self::Item>,
|
||||
{
|
||||
ChainIfEmpty {
|
||||
iter: self,
|
||||
other: other.into_iter(),
|
||||
state: State::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TakeWhileInclusive<I, P> {
|
||||
|
@ -55,6 +69,43 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
enum State {
|
||||
Empty,
|
||||
NotEmpty,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
pub struct ChainIfEmpty<I, U> {
|
||||
iter: I,
|
||||
other: U,
|
||||
state: State,
|
||||
}
|
||||
|
||||
impl<I, U> Iterator for ChainIfEmpty<I, U>
|
||||
where
|
||||
I: Iterator,
|
||||
U: Iterator<Item = I::Item>,
|
||||
{
|
||||
type Item = I::Item;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
match self.state {
|
||||
State::Empty => self.other.next(),
|
||||
State::NotEmpty => self.iter.next(),
|
||||
State::Unknown => match self.iter.next() {
|
||||
Some(x) => {
|
||||
self.state = State::NotEmpty;
|
||||
Some(x)
|
||||
}
|
||||
None => {
|
||||
self.state = State::Empty;
|
||||
self.other.next()
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests_take_while_inclusive {
|
||||
use super::*;
|
||||
|
@ -79,3 +130,25 @@ mod tests_take_while_inclusive {
|
|||
test!(always_true, [1, 2, 3, 4], |_| true, vec![1, 2, 3, 4]);
|
||||
test!(always_false, [1, 2, 3, 4], |_| false, vec![1]);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests_chain_if_empty {
|
||||
use super::*;
|
||||
|
||||
macro_rules! test {
|
||||
($name:ident, $inp1:expr, $inp2:expr, $exp: expr) => {
|
||||
#[test]
|
||||
fn $name() {
|
||||
assert_eq!(
|
||||
$inp1.into_iter().chain_if_empty($inp2).collect::<Vec<_>>(),
|
||||
$exp
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
test!(both_empty, Vec::<i32>::new(), [], vec![]);
|
||||
test!(first_empty, [], [4, 5, 6], vec![4, 5, 6]);
|
||||
test!(second_empty, [1, 2, 3], [], vec![1, 2, 3]);
|
||||
test!(both_nonempty, [1, 2, 3], [4, 5, 6], vec![1, 2, 3]);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue