matsutaku-library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub MatsuTaku/matsutaku-library

:heavy_check_mark: test/standalone/skiplist_set_test.cpp

Code

#define STANDALONE
#include "include/mtl/skiplist.hpp"

#include <iostream>
#include <vector>
#include <cassert>

int main() {
  constexpr int kUniverse = 400000;  // candidate keys are 0 .. kUniverse-1

  // Keep each candidate with probability 1/4 (rand() is never seeded, so the
  // sample is deterministic); retry in the unlikely case nothing was kept.
  std::vector<int> values;
  do {
    for (int key = 0; key < kUniverse; ++key) {
      if (rand() % 4 == 0)
        values.push_back(key);
    }
  } while (values.empty());
  const int n = static_cast<int>(values.size());

  // Scramble the insertion order with n random swaps.
  std::vector<int> shuffled(values);
  for (int a = 0; a < n; ++a) {
    const int b = rand() % n;
    std::swap(shuffled[a], shuffled[b]);
  }

  SkiplistSet<int> S;
  for (int key : shuffled)
    S.insert(key);
  // S.print_for_debug();  // uncomment to dump the level structure

  // Rank access: the pos-th element must be the pos-th smallest value.
  for (int pos = 0; pos < n; ++pos) {
    const int got = *S.get_at(pos);
    if (got != values[pos]) {
      std::cout << "get " << pos << " " << got << " != " << values[pos] << std::endl;
      assert(false);
      return 1;
    }
  }

  // Expected answers for the key currently being probed; -1 means "none".
  int target = -1;       // the key itself when it is in the set
  int pred = -1;         // largest value smaller than the key
  int succ = values[0];  // smallest value larger than the key
  int next_idx = 0;      // number of values smaller than the key

  auto report = [&]() {
    std::cout << pred << ' ' << target << ' ' << succ << std::endl;
    assert(false);
  };
  // True when the iterator disagrees with `want`. Prints `label` followed by
  // the returned element when the set handed back a different value.
  auto disagrees = [&](auto it, int want, const char* label) {
    if (it == S.end())
      return want != -1;
    if (*it == want)
      return false;
    std::cout << label << *it << std::endl;
    return true;
  };

  for (int x = 0; x < kUniverse; ++x) {
    pred = next_idx > 0 ? values[next_idx - 1] : -1;
    if (next_idx < n && values[next_idx] == x) {
      target = x;
      succ = next_idx + 1 < n ? values[next_idx + 1] : -1;
      ++next_idx;
    } else {
      target = -1;
    }

    auto found = S.find(x);
    if (found == S.end()) {
      if (target != -1) {
        report();
        return 1;
      }
    } else if (*found != x) {
      std::cout << "find: " << x << std::endl;
      report();
      return 1;
    }

    if (disagrees(S.successor(x), succ, "succ: ")) {
      report();
      return 1;
    }
    if (disagrees(S.predecessor(x), pred, "pred: ")) {
      report();
      return 1;
    }
  }

  // Erase everything in shuffled order, checking the size after each step.
  int remaining = n;
  auto size_is_off = [&]() {
    if (static_cast<int>(S.size()) == remaining)
      return false;
    std::cout << S.size() << ' ' << remaining << std::endl;
    return true;
  };
  if (size_is_off()) {
    report();
    return 1;
  }
  for (int key : shuffled) {
    S.erase(key);
    --remaining;
    if (size_is_off()) {
      report();
      return 1;
    }
  }

  std::cout << "OK" << std::endl;
}
#line 1 "test/standalone/skiplist_set_test.cpp"
#define STANDALONE
#line 2 "include/mtl/bit_manip.hpp"
#include <cstdint>
#include <cassert>
#if __cplusplus >= 202002L
#ifndef MTL_CPP20
#define MTL_CPP20
#endif
#include <bit>
#endif

namespace bm {

/// Count 1s within each byte: byte k of the result is popcount(byte k of x).
inline constexpr uint64_t popcnt_e8(uint64_t x) {
  x = (x & 0x5555555555555555) + ((x>>1) & 0x5555555555555555);
  x = (x & 0x3333333333333333) + ((x>>2) & 0x3333333333333333);
  x = (x & 0x0F0F0F0F0F0F0F0F) + ((x>>4) & 0x0F0F0F0F0F0F0F0F);
  return x;
}
/// Count 1s.
inline constexpr unsigned popcnt(uint64_t x) {
#ifdef MTL_CPP20
  return std::popcount(x);
#else
  // Multiplying by 0x0101... sums all per-byte counts into the top byte.
  return (popcnt_e8(x) * 0x0101010101010101) >> 56;
#endif
}
/// Alias to bm::popcnt(x)
constexpr unsigned popcount(uint64_t x) {
  return popcnt(x);
}
/// Count trailing 0s. s.t. *11011000 -> 3; returns 64 for x == 0.
inline constexpr unsigned ctz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_zero(x);
#else
  // (x & -x) isolates the lowest set bit; subtracting 1 leaves exactly the
  // trailing zeros set (all 64 bits when x == 0).
  return popcnt((x & (-x)) - 1);
#endif
}
/// Alias to bm::ctz(x)
constexpr unsigned countr_zero(uint64_t x) {
  return ctz(x);
}
/// Count trailing 1s. s.t. *11011011 -> 2
inline constexpr unsigned cto(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_one(x);
#else
  return ctz(~x);
#endif
}
/// Alias to bm::cto(x)
constexpr unsigned countr_one(uint64_t x) {
  return cto(x);
}
/// Count trailing 0s of a byte; 8 for x == 0.
inline constexpr unsigned ctz8(uint8_t x) {
  return x == 0 ? 8 : popcnt_e8((x & (-x)) - 1);
}
/// One bit per byte: bit k of the result is 1 iff byte k of x is non-zero.
/// [00..0](8bit) -> 0, [**..*](not only 0) -> 1
inline constexpr uint8_t summary(uint64_t x) {
  constexpr uint64_t hmask = 0x8080808080808080ull;
  constexpr uint64_t lmask = 0x7F7F7F7F7F7F7F7Full;
  auto a = x & hmask;   // high bit of each byte
  auto b = x & lmask;   // low 7 bits of each byte
  b = hmask - b;        // per byte: high bit stays set only if the low 7 bits were 0
  b = ~b;               // per byte: high bit set iff the low 7 bits were non-zero
  auto c = (a | b) & hmask;
  c *= 0x0002040810204081ull;  // gather the 8 flag bits into the top byte
  return uint8_t(c >> 56);
}
/// Extract `len` bits of x starting at bit `start` (start must be < 64).
inline constexpr uint64_t bextr(uint64_t x, unsigned start, unsigned len) {
  uint64_t mask = len < 64 ? (1ull<<len)-1 : 0xFFFFFFFFFFFFFFFFull;
  return (x >> start) & mask;
}
/// Bit width of a byte (index of the highest set bit + 1); 0 for x == 0.
/// 00101101 -> 00111111 -count_1s-> 6
inline constexpr unsigned log2p1(uint8_t x) {
  if (x == 0)
    return 0;  // the SWAR comparison below would report 1 here (byte 7 ends at 0x7F)
  if (x & 0x80)
    return 8;
  // Broadcast x to every byte and subtract 2^k from byte k in parallel;
  // byte k's high bit stays clear exactly when x >= 2^k.
  uint64_t p = uint64_t(x) * 0x0101010101010101ull;
  p -= 0x8040201008040201ull;
  p = ~p & 0x8080808080808080ull;
  p = (p >> 7) * 0x0101010101010101ull;  // sum the flags into the top byte
  p >>= 56;
  return p;
}
/// Index of the most significant set bit of a non-zero byte.
/// 00101100 -mask_mssb-> 00100000 -to_index-> 5
inline constexpr unsigned mssb8(uint8_t x) {
  assert(x != 0);
  return log2p1(x) - 1;
}
/// Index of the least significant set bit of a non-zero byte.
/// 00101100 -mask_lssb-> 00000100 -to_index-> 2
inline constexpr unsigned lssb8(uint8_t x) {
  assert(x != 0);
  return popcnt_e8((x & -x) - 1);
}
/// Count leading 0s. 00001011... -> 4; returns 64 for x == 0.
inline constexpr unsigned clz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_zero(x);
#else
  if (x == 0)
    return 64;
  // Locate the highest non-zero byte, then the highest set bit inside it.
  auto i = mssb8(summary(x));
  auto j = mssb8(bextr(x, 8 * i, 8));
  return 63 - (8 * i + j);
#endif
}
/// Alias to bm::clz(x)
constexpr unsigned countl_zero(uint64_t x) {
  return clz(x);
}
/// Count leading 1s. 11110100... -> 4
inline constexpr unsigned clo(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_one(x);
#else
  return clz(~x);
#endif
}
/// Alias to bm::clo(x)
constexpr unsigned countl_one(uint64_t x) {
  return clo(x);
}

/// Count leading 0s of a byte; 8 for x == 0.
inline constexpr unsigned clz8(uint8_t x) {
  return x == 0 ? 8 : 7 - mssb8(x);
}
/// Reverse the bit order of a 64-bit word by swapping halves, then quarters,
/// and so on down to adjacent single bits.
inline constexpr uint64_t bit_reverse(uint64_t x) {
  x = ((x & 0x00000000FFFFFFFF) << 32) | ((x & 0xFFFFFFFF00000000) >> 32);
  x = ((x & 0x0000FFFF0000FFFF) << 16) | ((x & 0xFFFF0000FFFF0000) >> 16);
  x = ((x & 0x00FF00FF00FF00FF) << 8) | ((x & 0xFF00FF00FF00FF00) >> 8);
  x = ((x & 0x0F0F0F0F0F0F0F0F) << 4) | ((x & 0xF0F0F0F0F0F0F0F0) >> 4);
  x = ((x & 0x3333333333333333) << 2) | ((x & 0xCCCCCCCCCCCCCCCC) >> 2);
  x = ((x & 0x5555555555555555) << 1) | ((x & 0xAAAAAAAAAAAAAAAA) >> 1);
  return x;
}

/// Check if x is power of 2. 00100000 -> true, 00100001 -> false
constexpr bool has_single_bit(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::has_single_bit(x);
#else
  return x != 0 && (x & (x - 1)) == 0;
#endif
}

/// Bit width needs to represent x. 00110110 -> 6; 0 for x == 0.
constexpr int bit_width(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::bit_width(x);
#else
  return 64 - clz(x);
#endif
}

/// Ceil power of 2. 00110110 -> 01000000 (the result must fit in 64 bits).
constexpr uint64_t bit_ceil(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_ceil(x);
#else
  if (x == 0) return 1;
  return 1ull << bit_width(x - 1);
#endif
}

/// Floor power of 2. 00110110 -> 00100000; 0 for x == 0.
constexpr uint64_t bit_floor(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_floor(x);
#else
  if (x == 0) return 0;
  return 1ull << (bit_width(x) - 1);
#endif
}

} // namespace bm
#line 3 "include/mtl/skiplist.hpp"
#include <cstddef>
#include <iterator>
#include <memory>
#include <vector>
#include <random>
#include <utility>
#line 7 "include/mtl/skiplist.hpp"
#include <iostream>

/// Position-indexed skip list.
///
/// Elements are kept in the order given by their insertion index (no
/// comparisons are made); `get_at`, `insert_at` and `erase_at` address them by
/// 0-based index. Every node keeps, per level r, a forward link `next[r]` and
/// `length[r]`, the number of ranks that link skips (0 when `next[r]` is null).
/// The sentinel head is rank 0 and element i is rank i+1.
///
/// NOTE: nodes are owned through std::shared_ptr, so copying a Skiplist is
/// shallow -- the copy shares (and mutates) the same nodes as the original.
template<typename T>
class Skiplist {
 protected:
  /// A tower of height h, linked at levels 0..h.
  struct Node {
    T v;
    std::vector<std::shared_ptr<Node>> next;  // next[r]: successor at level r, nullptr at the end
    std::vector<int> length;                  // length[r]: ranks skipped by next[r], 0 if none
    Node(T v, int h) : v(std::move(v)), next(h+1, nullptr), length(h+1, 0) {}
  };

 public:
  /// Highest level index; the sentinel is allocated with kMaxHeight+1 levels.
  static constexpr int kMaxHeight = 32;

  /// Forward iterator walking the level-0 chain; end() holds nullptr.
  class iterator {
   public:
    using difference_type = std::ptrdiff_t;
    using value_type = T;
    using pointer = T*;
    using reference = T&;
    using iterator_category = std::forward_iterator_tag;
   private:
    std::shared_ptr<Node> ptr_;
   public:
    iterator(std::shared_ptr<Node> ptr) : ptr_(std::move(ptr)) {}
    T& operator*() const {
      return ptr_->v;
    }
    T* operator->() const {
      return &(ptr_->v);
    }
    /// Legacy accessor kept for existing callers: `&it` yields the element's address.
    T* operator&() {
      return &(ptr_->v);
    }
    iterator& operator++() {
      ptr_ = ptr_->next[0];
      return *this;
    }
    iterator operator++(int) {
      iterator ret = *this;
      ++*this;
      return ret;
    }
    bool operator==(const iterator& r) const {
      return ptr_ == r.ptr_;
    }
    bool operator!=(const iterator& r) const {
      return ptr_ != r.ptr_;
    }
  };

 protected:
  std::shared_ptr<Node> sentinel_;
  int height_;   // highest level currently in use
  size_t size_;

  std::default_random_engine rnd_gen;
  std::uniform_int_distribution<uint32_t> dist;

 public:
  Skiplist() :
      sentinel_(std::make_shared<Node>(T(), kMaxHeight)),
      height_(0),
      size_(0),
      rnd_gen(std::random_device()()),
      dist(0, (1ull<<kMaxHeight)-1) {}

  /// Build by appending [first, last) in order.
  template<class InputIt>
  Skiplist(InputIt first, InputIt last) : Skiplist() {
    // TODO: build in O(n) by remembering, for each level, the last node linked
    // at that level instead of searching from the head for every append.
    for (auto it = first; it != last; it++)
      insert_at(size(), *it);
  }

  // Copies are shallow (see class comment); declared explicitly because the
  // destructor below is user-provided.
  Skiplist(const Skiplist&) = default;
  Skiplist(Skiplist&&) = default;
  Skiplist& operator=(const Skiplist&) = default;
  Skiplist& operator=(Skiplist&&) = default;

  /// Releases the nodes one at a time. Letting the shared_ptr chain unwind on
  /// its own nests one destruction per node and can overflow the stack for
  /// large lists. Nothing is torn down while another copy still shares the
  /// sentinel (or after this object was moved from).
  ~Skiplist() {
    if (!sentinel_ || sentinel_.use_count() != 1)
      return;
    auto u = std::move(sentinel_->next[0]);
    for (auto& p : sentinel_->next)
      p.reset();
    while (u) {
      // Every later node is still owned by its level-0 predecessor, so
      // dropping u's links never triggers a cascade.
      auto nx = std::move(u->next[0]);
      for (auto& p : u->next)
        p.reset();
      u = std::move(nx);
    }
  }

  size_t size() const {return size_;}
  bool empty() const {return size() == 0;}

  iterator begin() const {
    return iterator(sentinel_->next[0]);
  }
  iterator end() const {
    return iterator(nullptr);
  }

  /// Iterator to the element at 0-based index i, or end() if i >= size().
  iterator get_at(size_t i) const {
    if (size() <= i)
      return end();
    auto u = sentinel_;
    i++;  // convert to a rank relative to the sentinel
    for (int r = height_; r >= 0; r--) {
      while (u->next[r] and u->length[r] < (int) i) {
        i -= u->length[r];
        u = u->next[r];
      }
      if (u->next[r] and u->length[r] == (int) i)
        return iterator(u->next[r]);
    }
    assert(false);
    return end();
  }
  /// Overwrite the element at index i; returns its iterator, or end() if out of range.
  iterator set_at(int i, T v) {
    auto it = get_at(i);
    if (it != end())
      *it = std::move(v);
    return it;
  }

  /// Random tower height: geometric with p = 1/2, capped at kMaxHeight.
  int pick_height() {
    // ctz of a zero word is 64, which would index past the sentinel's levels.
    int h = (int) bm::ctz(dist(rnd_gen));
    return h < kMaxHeight ? h : kMaxHeight;
  }

  /// Insert v so that it becomes the element at index i (0 <= i <= size()).
  /// Returns its iterator, or end() if i is out of range.
  iterator insert_at(int i, T v) {
    if (i < 0 || static_cast<size_t>(i) > size())
      return end();
    i++;
    auto u = sentinel_;
    int hw = pick_height();
    if (hw > height_)
      height_ = hw;
    auto w = std::make_shared<Node>(std::move(v), hw);
    for (int r = height_; r >= 0; r--) {
      while (u->next[r] and u->length[r] < i) {
        i -= u->length[r];
        u = u->next[r];
      }
      if (r <= hw) {
        // Splice w in i ranks after u; the old successor moves one rank further.
        w->length[r] = u->next[r] ? u->length[r] - (i-1) : 0;
        u->length[r] = i;
        w->next[r] = u->next[r];
        u->next[r] = w;
      } else if (u->next[r]) {
        u->length[r]++;  // this link jumps over the new element
      }
    }
    size_++;
    return iterator(w);
  }

  /// Erase the element at index i. Returns an iterator to the element that
  /// follows it, or end() if i is out of range.
  iterator erase_at(int i) {
    if (i < 0 || size() <= static_cast<size_t>(i))
      return end();
    auto u = sentinel_;
    i++;
    for (int r = height_; r >= 0; r--) {
      while (u->next[r] and u->length[r] < i) {
        i -= u->length[r];
        u = u->next[r];
      }
      if (u->next[r] and u->length[r] == i) {
        // Unlink the target here; u's link now spans both hops minus one rank.
        if (u->next[r]->length[r]) {
          u->length[r] = u->length[r] + u->next[r]->length[r] - 1;
        } else {
          u->length[r] = 0;
        }
        u->next[r] = u->next[r]->next[r];
      } else if (u->next[r]) {
        u->length[r]--;  // this link jumped over the removed element
      }
    }
    size_--;
    return iterator(u->next[0]);
  }
};
// Out-of-class definition: required before C++17 because kMaxHeight is
// ODR-used (make_shared forwards it by reference); redundant but harmless since.
template<typename T>
constexpr int Skiplist<T>::kMaxHeight;


/// Ordered set built on Skiplist. Elements are kept sorted by `<` and unique
/// under `==`, and the base class's rank bookkeeping stays valid, so
/// `get_at(i)` yields the i-th smallest element.
template<typename T>
class SkiplistSet : public Skiplist<T> {
 private:
  using _base = Skiplist<T>;
  using typename _base::Node;
  // Scratch buffers reused by emplace/erase: per level, the last node visited
  // (stack_) and its rank (idx_).
  std::vector<std::shared_ptr<Node>> stack_;
  std::vector<size_t> idx_;
 public:
  using typename _base::iterator;

  SkiplistSet() : Skiplist<T>(),
                  stack_(_base::kMaxHeight+1, _base::sentinel_),
                  idx_(_base::kMaxHeight+1, 0) {}

  /// Print each level's link lengths, then the level-0 values.
  void print_for_debug() const {
    for (int r = _base::height_; r >= 0; r--){
      auto u = _base::sentinel_;
      int l = u->length[r];
      std::cout << "-"<<u->length[r] << "-";
      for (int i = 0; i < l-1; i++)
        std::cout<<"    ";
      u = u->next[r];
      while (u) {
        l = u->length[r];
        std::cout << "|-"<<u->length[r] << "-";
        for (int i = 0; i < l-1; i++)
          std::cout<<"    ";
        u = u->next[r];
      }
      std::cout<<std::endl;
    }
    auto u = _base::sentinel_;
    u = u->next[0];
    std::cout<<" ";
    while (u) {
      std::cout << "  "<<u->v<<' ';
      u = u->next[0];
    }
    std::cout<<std::endl;
  }

  /// Iterator to the element equal to v, or end().
  iterator find(const T& v) const {
    auto u = _base::sentinel_;
    for (int r = _base::height_; r >= 0; r--) {
      while (u->next[r] and u->next[r]->v < v)
        u = u->next[r];
      if (u->next[r] and u->next[r]->v == v)
        return iterator(u->next[r]);
    }
    return _base::end();
  }
  /// 1 if v is present, else 0.
  size_t count(const T& v) const {
    return size_t(find(v) != _base::end());
  }
  /// First element not less than v, or end().
  iterator lower_bound(const T& v) const {
    auto u = _base::sentinel_;
    for (int r = _base::height_; r >= 0; r--) {
      while (u->next[r] and u->next[r]->v < v)
        u = u->next[r];
      if (u->next[r] and u->next[r]->v == v)
        return iterator(u->next[r]);
    }
    return iterator(u->next[0]);
  }
  /// First element greater than v, or end().
  iterator upper_bound(const T& v) const {
    auto u = _base::sentinel_;
    for (int r = _base::height_; r >= 0; r--) {
      while (u->next[r] and u->next[r]->v <= v)
        u = u->next[r];
    }
    return iterator(u->next[0]);
  }
  /// Smallest element strictly greater than v, or end().
  iterator successor(const T& v) const {
    return upper_bound(v);
  }
  /// Largest element strictly less than v, or end().
  iterator predecessor(const T& v) const {
    auto u = _base::sentinel_;
    for (int r = _base::height_; r >= 0; r--) {
      while (u->next[r] and u->next[r]->v < v)
        u = u->next[r];
    }
    return u != _base::sentinel_ ? iterator(u) : _base::end();
  }

  /// Construct a value from args and insert it unless an equal one exists.
  /// Returns {iterator to the element, whether it was inserted}.
  template<class... Args>
  std::pair<iterator, bool> emplace(Args&&... args) {
    T v(std::forward<Args>(args)...);
    auto u = _base::sentinel_;
    size_t j = 1;  // after the descent: the rank the new element will take
    for (int r = _base::height_; r >= 0; r--) {
      while (u->next[r] and u->next[r]->v < v) {
        j += u->length[r];
        u = u->next[r];
      }
      if (u->next[r] and u->next[r]->v == v)
        return std::make_pair(iterator(u->next[r]), false);
      stack_[r] = u;
      idx_[r] = j-1;
    }
    int hw = _base::pick_height();
    if (hw > _base::kMaxHeight)
      hw = _base::kMaxHeight;  // stack_/idx_ and the sentinel hold kMaxHeight+1 levels
    // Levels above the current height are only reachable from the sentinel.
    for (int r = _base::height_ + 1; r <= hw; r++) {
      stack_[r] = _base::sentinel_;
      idx_[r] = 0;
    }
    if (_base::height_ < hw)
      _base::height_ = hw;
    auto w = std::make_shared<Node>(std::move(v), hw);
    for (int r = _base::height_; r >= 0; r--) {
      auto& p = stack_[r];
      if (r <= hw) {
        assert(idx_[r] < j);
        const int offset = static_cast<int>(j - idx_[r]);  // rank distance p -> w
        w->length[r] = p->next[r] ? p->length[r] - offset + 1 : 0;
        p->length[r] = offset;
        w->next[r] = p->next[r];
        p->next[r] = w;
      } else if (p->next[r]) {
        p->length[r]++;  // this link now jumps over w
      }
    }
    _base::size_++;
    return std::make_pair(iterator(w), true);
  }
  std::pair<iterator, bool> insert(const T& v) {
    return emplace(v);
  }
  std::pair<iterator, bool> insert(T&& v) {
    return emplace(std::move(v));
  }

  /// Remove v if present. Returns an iterator to the first element greater
  /// than v (or end()).
  iterator erase(const T& v) {
    auto u = _base::sentinel_;
    int r = _base::height_;
    // Descend to the highest level at which v is linked (its tower top),
    // remembering the predecessor on every level above it.
    for (; r >= 0; r--) {
      while (u->next[r] and u->next[r]->v < v)
        u = u->next[r];
      if (u->next[r] and u->next[r]->v == v)
        break;
      stack_[r] = u;
    }
    if (r < 0)  // not found: u is the level-0 predecessor of v
      return iterator(u->next[0]);

    _base::size_--;
    // Links above the tower jump over the removed element and shrink by one rank.
    for (int i = _base::height_; i > r; i--)
      if (stack_[i]->next[i])
        stack_[i]->length[i]--;
    // The node is linked on every level 0..r, but shorter towers may sit
    // between u and it on lower levels, so advance u before each unlink.
    for (; r >= 0; r--) {
      while (u->next[r] and u->next[r]->v < v)
        u = u->next[r];
      auto target = u->next[r];
      assert(target and target->v == v);
      if (target->next[r]) {
        u->length[r] = u->length[r] + target->length[r] - 1;
      } else {
        u->length[r] = 0;
      }
      u->next[r] = target->next[r];
    }
    return iterator(u->next[0]);
  }
};
#line 3 "test/standalone/skiplist_set_test.cpp"

#line 7 "test/standalone/skiplist_set_test.cpp"

int main() {
  constexpr int kUniverse = 400000;  // candidate keys are 0 .. kUniverse-1

  // Keep each candidate with probability 1/4 (rand() is never seeded, so the
  // sample is deterministic); retry in the unlikely case nothing was kept.
  std::vector<int> values;
  do {
    for (int key = 0; key < kUniverse; ++key) {
      if (rand() % 4 == 0)
        values.push_back(key);
    }
  } while (values.empty());
  const int n = static_cast<int>(values.size());

  // Scramble the insertion order with n random swaps.
  std::vector<int> shuffled(values);
  for (int a = 0; a < n; ++a) {
    const int b = rand() % n;
    std::swap(shuffled[a], shuffled[b]);
  }

  SkiplistSet<int> S;
  for (int key : shuffled)
    S.insert(key);
  // S.print_for_debug();  // uncomment to dump the level structure

  // Rank access: the pos-th element must be the pos-th smallest value.
  for (int pos = 0; pos < n; ++pos) {
    const int got = *S.get_at(pos);
    if (got != values[pos]) {
      std::cout << "get " << pos << " " << got << " != " << values[pos] << std::endl;
      assert(false);
      return 1;
    }
  }

  // Expected answers for the key currently being probed; -1 means "none".
  int target = -1;       // the key itself when it is in the set
  int pred = -1;         // largest value smaller than the key
  int succ = values[0];  // smallest value larger than the key
  int next_idx = 0;      // number of values smaller than the key

  auto report = [&]() {
    std::cout << pred << ' ' << target << ' ' << succ << std::endl;
    assert(false);
  };
  // True when the iterator disagrees with `want`. Prints `label` followed by
  // the returned element when the set handed back a different value.
  auto disagrees = [&](auto it, int want, const char* label) {
    if (it == S.end())
      return want != -1;
    if (*it == want)
      return false;
    std::cout << label << *it << std::endl;
    return true;
  };

  for (int x = 0; x < kUniverse; ++x) {
    pred = next_idx > 0 ? values[next_idx - 1] : -1;
    if (next_idx < n && values[next_idx] == x) {
      target = x;
      succ = next_idx + 1 < n ? values[next_idx + 1] : -1;
      ++next_idx;
    } else {
      target = -1;
    }

    auto found = S.find(x);
    if (found == S.end()) {
      if (target != -1) {
        report();
        return 1;
      }
    } else if (*found != x) {
      std::cout << "find: " << x << std::endl;
      report();
      return 1;
    }

    if (disagrees(S.successor(x), succ, "succ: ")) {
      report();
      return 1;
    }
    if (disagrees(S.predecessor(x), pred, "pred: ")) {
      report();
      return 1;
    }
  }

  // Erase everything in shuffled order, checking the size after each step.
  int remaining = n;
  auto size_is_off = [&]() {
    if (static_cast<int>(S.size()) == remaining)
      return false;
    std::cout << S.size() << ' ' << remaining << std::endl;
    return true;
  };
  if (size_is_off()) {
    report();
    return 1;
  }
  for (int key : shuffled) {
    S.erase(key);
    --remaining;
    if (size_is_off()) {
      report();
      return 1;
    }
  }

  std::cout << "OK" << std::endl;
}
Back to top page