238 lines
5.5 KiB
Rust
238 lines
5.5 KiB
Rust
use std::{
|
|
iter,
|
|
ops::{BitAnd, BitOr, Sub},
|
|
};
|
|
|
|
use itertools::Itertools;
|
|
|
|
/// A growable set of `usize` values stored as a packed bit vector.
///
/// Element `x` is represented by bit `x % 64` of word `x / 64`. The number of
/// members is cached so `len` is O(1). Trailing words may be all zero (e.g.
/// after removals); `shrink` trims them.
#[derive(Clone, Debug, Default)]
pub struct BitSet {
    /// Backing words; word `k` holds the membership bits for `64 * k .. 64 * k + 64`.
    bits: Vec<u64>,
    /// Cached count of set bits across `bits`, kept in sync by every mutator.
    len: usize,
}
|
|
|
|
impl BitSet {
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
pub fn contains(&self, x: usize) -> bool {
|
|
let (i, j) = key(x);
|
|
i < self.bits.len() && self.bits[i] & j != 0
|
|
}
|
|
|
|
pub fn insert(&mut self, x: usize) {
|
|
let (i, j) = key(x);
|
|
self.extend_to_len(i + 1);
|
|
if self.bits[i] & j == 0 {
|
|
self.len += 1;
|
|
}
|
|
self.bits[i] |= j;
|
|
}
|
|
|
|
pub fn remove(&mut self, x: usize) {
|
|
let (i, j) = key(x);
|
|
if i < self.bits.len() {
|
|
if self.bits[i] & j != 0 {
|
|
self.len -= 1;
|
|
}
|
|
self.bits[i] &= !j;
|
|
}
|
|
}
|
|
|
|
pub fn shrink(&mut self) {
|
|
while self.bits.last() == Some(&0) {
|
|
self.bits.pop();
|
|
}
|
|
self.bits.shrink_to_fit();
|
|
}
|
|
|
|
fn extend_to_len(&mut self, len: usize) {
|
|
if len > self.bits.len() {
|
|
self.bits
|
|
.extend(iter::repeat(0).take(len - self.bits.len()));
|
|
}
|
|
}
|
|
|
|
pub fn iter(&self) -> BitSetIter<'_> {
|
|
self.into_iter()
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.len == 0
|
|
}
|
|
|
|
pub fn len(&self) -> usize {
|
|
self.len
|
|
}
|
|
}
|
|
|
|
impl BitAnd<&Self> for BitSet {
|
|
type Output = Self;
|
|
|
|
fn bitand(mut self, rhs: &Self) -> Self::Output {
|
|
let mut removed = self
|
|
.bits
|
|
.iter()
|
|
.skip(rhs.bits.len())
|
|
.map(|x| x.count_ones())
|
|
.sum::<u32>();
|
|
self.bits.truncate(rhs.bits.len());
|
|
for (a, &b) in self.bits.iter_mut().zip(&rhs.bits) {
|
|
removed += (*a & !b).count_ones();
|
|
*a &= b;
|
|
}
|
|
self.len -= removed as usize;
|
|
self
|
|
}
|
|
}
|
|
|
|
impl BitOr<&Self> for BitSet {
|
|
type Output = Self;
|
|
|
|
fn bitor(mut self, rhs: &Self) -> Self::Output {
|
|
self.extend_to_len(rhs.bits.len());
|
|
let mut added = 0;
|
|
for (a, &b) in self.bits.iter_mut().zip(&rhs.bits) {
|
|
added += (!*a & b).count_ones();
|
|
*a |= b;
|
|
}
|
|
self.len += added as usize;
|
|
self
|
|
}
|
|
}
|
|
|
|
impl Sub<&Self> for BitSet {
|
|
type Output = Self;
|
|
|
|
fn sub(mut self, rhs: &Self) -> Self::Output {
|
|
let mut removed = 0;
|
|
for (a, &b) in self.bits.iter_mut().zip(&rhs.bits) {
|
|
removed += (*a & b).count_ones();
|
|
*a &= !b;
|
|
}
|
|
self.len -= removed as usize;
|
|
self
|
|
}
|
|
}
|
|
|
|
impl PartialEq for BitSet {
|
|
fn eq(&self, other: &Self) -> bool {
|
|
if self.len != other.len {
|
|
return false;
|
|
}
|
|
self.bits.iter().zip_longest(&other.bits).all(|x| {
|
|
let (a, b) = x.left_and_right();
|
|
a.unwrap_or(&0) == b.unwrap_or(&0)
|
|
})
|
|
}
|
|
}
|
|
|
|
impl Eq for BitSet {}
|
|
|
|
impl FromIterator<usize> for BitSet {
|
|
fn from_iter<T: IntoIterator<Item = usize>>(iter: T) -> Self {
|
|
let mut set = Self::new();
|
|
for x in iter {
|
|
set.insert(x);
|
|
}
|
|
set
|
|
}
|
|
}
|
|
|
|
pub struct BitSetIter<'a> {
|
|
i: usize,
|
|
j: usize,
|
|
set: &'a BitSet,
|
|
}
|
|
|
|
impl<'a> IntoIterator for &'a BitSet {
|
|
type Item = <Self::IntoIter as Iterator>::Item;
|
|
|
|
type IntoIter = BitSetIter<'a>;
|
|
|
|
fn into_iter(self) -> Self::IntoIter {
|
|
BitSetIter {
|
|
i: 0,
|
|
j: 0,
|
|
set: self,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Iterator for BitSetIter<'_> {
|
|
type Item = usize;
|
|
|
|
fn next(&mut self) -> Option<Self::Item> {
|
|
while *self.set.bits.get(self.i)? >> self.j == 0 {
|
|
self.i += 1;
|
|
self.j = 0;
|
|
}
|
|
|
|
while self.set.bits[self.i] & (1 << self.j) == 0 {
|
|
self.j += 1;
|
|
}
|
|
|
|
let x = (self.i << 6) | self.j;
|
|
self.j += 1;
|
|
if self.j == 64 {
|
|
self.j = 0;
|
|
self.i += 1;
|
|
}
|
|
Some(x)
|
|
}
|
|
}
|
|
|
|
/// Splits element `x` into `(word index, single-bit mask within that word)`.
///
/// Element `x` lives at bit `x % 64` of word `x / 64`.
fn key(x: usize) -> (usize, u64) {
    let word = x / 64;
    let mask = 1u64 << (x % 64);
    (word, mask)
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn set() {
        let mut set = BitSet::new();
        assert_eq!(set.iter().collect::<Vec<_>>(), []);
        assert!(!set.contains(1));
        assert!(!set.contains(2));
        assert!(!set.contains(3));
        assert!(!set.contains(1337));
        assert_eq!(set.len(), 0);
        set.insert(1);
        set.insert(2);
        set.insert(1337);
        set.insert(3);
        assert!(set.contains(1));
        assert!(set.contains(2));
        assert!(set.contains(3));
        assert!(set.contains(1337));
        assert_eq!(set.len(), 4);
        assert_eq!(set.iter().collect::<Vec<_>>(), [1, 2, 3, 1337]);
        set.remove(2);
        assert_eq!(set.len(), 3);
        assert_eq!(set.iter().collect::<Vec<_>>(), [1, 3, 1337]);
        set.remove(1337);
        assert_eq!(set.len(), 2);
        assert_eq!(set, BitSet::from_iter([1, 3]));
        assert_ne!(set.bits, BitSet::from_iter([1, 3]).bits);
        assert_eq!(set.bits.len(), 21);
        set.shrink();
        assert_eq!(set.bits.len(), 1);
        assert_eq!(set.iter().collect::<Vec<_>>(), [1, 3]);
    }

    #[test]
    fn ops() {
        let a = BitSet::from_iter([1, 2, 3]);
        let b = BitSet::from_iter([3, 4]);

        assert_eq!(a.clone() & &b, BitSet::from_iter([3]));
        assert_eq!((a.clone() & &b).len(), 1);
        assert_eq!(a.clone() | &b, BitSet::from_iter([1, 2, 3, 4]));
        assert_eq!((a.clone() | &b).len(), 4);
        assert_eq!(a.clone() - &b, BitSet::from_iter([1, 2]));
        assert_eq!((a.clone() - &b).len(), 2);
    }

    /// Bits 63/64 and 127/128 straddle word boundaries, where the iterator
    /// must wrap its bit offset and advance to the next word.
    #[test]
    fn iter_crosses_word_boundaries() {
        let set = BitSet::from_iter([0, 63, 64, 127, 128, 200]);
        assert_eq!(set.len(), 6);
        assert_eq!(
            set.iter().collect::<Vec<_>>(),
            [0, 63, 64, 127, 128, 200]
        );

        // Once exhausted, the iterator keeps returning `None`.
        let mut it = set.iter();
        assert_eq!(it.by_ref().count(), 6);
        assert_eq!(it.next(), None);
        assert_eq!(it.next(), None);
    }

    #[test]
    fn duplicate_insert_and_absent_remove_keep_len() {
        let mut set = BitSet::new();
        set.insert(5);
        set.insert(5);
        assert_eq!(set.len(), 1);
        set.remove(6);
        set.remove(10_000);
        assert_eq!(set.len(), 1);
        set.remove(5);
        assert!(set.is_empty());
        assert_eq!(set.iter().next(), None);
    }

    #[test]
    fn ops_across_unequal_lengths() {
        let small = BitSet::from_iter([1, 70]);
        let big = BitSet::from_iter([1, 500]);

        assert_eq!(small.clone() & &big, BitSet::from_iter([1]));
        assert_eq!((small.clone() & &big).len(), 1);
        assert_eq!(big.clone() & &small, BitSet::from_iter([1]));
        assert_eq!((big.clone() & &small).len(), 1);

        assert_eq!(small.clone() | &big, BitSet::from_iter([1, 70, 500]));
        assert_eq!((small.clone() | &big).len(), 3);

        assert_eq!(small.clone() - &big, BitSet::from_iter([70]));
        assert_eq!(big.clone() - &small, BitSet::from_iter([500]));
        assert_eq!((big - &small).len(), 1);
    }

    #[test]
    fn eq_compares_contents_not_just_len() {
        assert_ne!(BitSet::from_iter([1]), BitSet::from_iter([2]));
        assert_ne!(BitSet::from_iter([1]), BitSet::from_iter([65]));
        assert_eq!(BitSet::new(), BitSet::default());
    }
}
|