matsutaku-library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub MatsuTaku/matsutaku-library

✔ test/standalone/xft_test.cpp

Depends on

Code

#define STANDALONE
#include "include/mtl/xft.hpp"
#include "set_test.hpp"

int main() {
  mtl::integer_set_test<XFastTrieSet<unsigned, 20>, 1<<20>();
  mtl::integer_set_test<XFastTrieSet<unsigned, 20>, 1<<20, false>();
  mtl::map_emplace_test<XFastTrieMap<unsigned, std::vector<int>>>();
  std::cout << "OK" << std::endl;
}
#line 1 "test/standalone/xft_test.cpp"
#define STANDALONE
#line 2 "include/mtl/traits/set_traits.hpp"
#include <cstddef>
#include <cstdlib>
#include <initializer_list>
#include <iterator>
#include <type_traits>
#include <utility>
#if __cplusplus >= 202002L
#include <concepts>
#endif

namespace traits {

/// Compile-time description of an associative container's element layout.
///
/// Map flavour (`M` != void): elements are `std::pair<T const, M>` and
/// `init_type` is the non-const pair used to build them.  The `void`
/// specialization below describes a set, whose element is the key itself.
template<typename T, typename M>
struct AssociativeArrayDefinition {
  using key_type = T;
  using mapped_type = M;
  using value_type = std::pair<T const, M>;
  using raw_key_type = typename std::remove_const<T>::type;
  using raw_mapped_type = typename std::remove_const<M>::type;
  using init_type = std::pair<raw_key_type, raw_mapped_type>;
  using moved_type = std::pair<raw_key_type&&, raw_mapped_type&&>;
  /// Key stored in `kv.first`.
  ///
  /// If `kv.first` already has the key type (ignoring cv/ref) it is returned
  /// by reference.  Otherwise it is converted and returned *by value*.
  /// Returning `key_type const&` unconditionally would bind the reference to
  /// a temporary created inside this function, leaving it dangling.
  template<class K, class V>
  static typename std::conditional<
      std::is_same<typename std::decay<K>::type, raw_key_type>::value,
      key_type const&,
      raw_key_type
  >::type key_of(std::pair<K,V> const& kv) {
    return kv.first;
  }
};
/// Set flavour: the element *is* the key.
///
/// `key_of` returns its argument by reference.  If called with a value of
/// another (convertible) type, the reference refers to the caller-side
/// temporary and is only valid until the end of the caller's full-expression.
template<typename T>
struct AssociativeArrayDefinition<T, void> {
  using key_type = T;
  using value_type = T;
  using init_type = T;
  static key_type const& key_of(value_type const& k) { return k; }
};

/// Resolves the "const iterator" type for container-like `T`.
///
/// The primary template is declared first so that the preferred case (the
/// container publishes `const_iterator`) can be written before the fallback.
template<class T, typename = std::void_t<>>
struct get_const_iterator;

/// Preferred case: `T::const_iterator` exists and is used verbatim.
template<class T>
struct get_const_iterator<T, std::void_t<typename T::const_iterator>> {
  using type = typename T::const_iterator;
};

/// Fallback: a thin wrapper deriving from `T::iterator`.  It is implicitly
/// constructible from an iterator (by copy or by move), so code written
/// against `const_iterator` still compiles.  It only renames the type; it
/// adds no constness to dereferencing.
template<class T, typename>
struct get_const_iterator {
  using base = typename T::iterator;
  struct type : base {
    type(base&& moved) : base(std::move(moved)) {}
    type(const base& copied) : base(copied) {}
  };
};

#if __cplusplus >= 202002L
// <concepts> is included at global scope together with the other standard
// headers.  Including a standard header here, inside `namespace traits`, is
// ill-formed unless an earlier include had already pulled it in.
/// Minimal container interface that SetTraitsBase builds on: key/value/
/// iterator typedefs plus size/empty/clear and begin/end returning `iterator`.
template<class M>
concept IsAssociativeArray = requires (M m) {
  typename M::key_type;
  typename M::value_type;
  typename M::iterator;
  {m.size()} -> std::convertible_to<size_t>;
  {m.empty()} -> std::same_as<bool>;
  {m.clear()};
  {m.begin()} -> std::same_as<typename M::iterator>;
  {m.end()} -> std::same_as<typename M::iterator>;
};
#endif

/// Adapts a trie-like `Base` to a std::set / std::map style interface.
///
/// The public members forward to `Base`'s hooks `_lower_bound`,
/// `_upper_bound`, `_find`, `_insert`, `_emplace_hint` and `_erase`.  When
/// `Base` has no `const_iterator`, `get_const_iterator` supplies a wrapper
/// around `Base::iterator`.
template<class Base>
#if __cplusplus >= 202002L
requires IsAssociativeArray<Base>
#endif
class SetTraitsBase : public Base {
 public:
  using key_type = typename Base::key_type;
  using value_type = typename Base::value_type;
  using init_type = typename Base::init_type;
  using iterator = typename Base::iterator;
  SetTraitsBase() = default;
  template<typename InputIt>
  explicit SetTraitsBase(InputIt begin, InputIt end) : Base(begin, end) {
    static_assert(std::is_convertible<typename std::iterator_traits<InputIt>::value_type, value_type>::value, "");
  }
  SetTraitsBase(std::initializer_list<value_type> init) : Base(init.begin(), init.end()) {}
  using Base::size;
  bool empty() const { return size() == 0; }
  using Base::clear;
  using const_iterator = typename get_const_iterator<Base>::type;
  iterator begin() {
    return Base::begin();
  }
  iterator end() {
    return Base::end();
  }
  const_iterator begin() const {
    return const_iterator(Base::begin());
  }
  const_iterator end() const {
    return const_iterator(Base::end());
  }
  const_iterator cbegin() const {
    return begin();
  }
  const_iterator cend() const {
    return end();
  }
  using reverse_iterator = std::reverse_iterator<iterator>;
  using reverse_const_iterator = std::reverse_iterator<const_iterator>;
  /// Standard-library spelling of `reverse_const_iterator`, so generic code
  /// written against std containers (`C::const_reverse_iterator`) compiles.
  using const_reverse_iterator = reverse_const_iterator;
  reverse_iterator rbegin() {
    return std::make_reverse_iterator(end());
  }
  reverse_iterator rend() {
    return std::make_reverse_iterator(begin());
  }
  reverse_const_iterator rbegin() const {
    return std::make_reverse_iterator(end());
  }
  reverse_const_iterator rend() const {
    return std::make_reverse_iterator(begin());
  }
  reverse_const_iterator crbegin() const {
    return rbegin();
  }
  reverse_const_iterator crend() const {
    return rend();
  }
  template<class Key>
  iterator lower_bound(const Key& x) const {
    return Base::_lower_bound(x);
  }
  iterator lower_bound(const key_type& x) const {
    return Base::_lower_bound(x);
  }
  template<class Key>
  iterator upper_bound(const Key& x) const {
    return Base::_upper_bound(x);
  }
  iterator upper_bound(const key_type& x) const {
    return Base::_upper_bound(x);
  }
  template<class Key>
  iterator find(const Key& x) {
    return Base::_find(x);
  }
  iterator find(const key_type& x) {
    return Base::_find(x);
  }
  template<class Key>
  const_iterator find(const Key& x) const {
    return Base::_find(x);
  }
  const_iterator find(const key_type& x) const {
    return Base::_find(x);
  }
  template<class Key>
  size_t count(const Key& x) const {
    return find(x) != end();
  }
  size_t count(const key_type& x) const {
    return find(x) != end();
  }
  /// True iff an element with key `x` is stored (C++20 `contains` spelling).
  template<class Key>
  bool contains(const Key& x) const {
    return find(x) != end();
  }
  bool contains(const key_type& x) const {
    return find(x) != end();
  }
  std::pair<iterator, bool> insert(const init_type& v) {
    return Base::_insert(v);
  }
  std::pair<iterator, bool> insert(init_type&& v) {
    return Base::_insert(std::move(v));
  }
  // Templated so that, when value_type == init_type (sets), the non-template
  // overloads above win overload resolution instead of being ambiguous.
  template<typename=void>
  std::pair<iterator, bool> insert(const value_type& v) {
    return Base::_insert(v);
  }
  template<typename=void>
  std::pair<iterator, bool> insert(value_type&& v) {
    return Base::_insert(std::move(v));
  }
  /// Build the element from `args`, as `init_type` when that is possible and
  /// as `value_type` otherwise, then insert it.
  template<class... Args>
  std::pair<iterator, bool> emplace(Args&&... args) {
    using emplace_type = typename std::conditional<
        std::is_constructible<init_type, Args...>::value,
        init_type,
        value_type
    >::type;
    return Base::_insert(emplace_type(std::forward<Args>(args)...));
  }
  template<class... Args>
  iterator emplace_hint(const_iterator hint, Args&&... args) {
    using emplace_type = typename std::conditional<
        std::is_constructible<init_type, Args...>::value,
        init_type,
        value_type
    >::type;
    return Base::_emplace_hint(hint, emplace_type(std::forward<Args>(args)...));
  }
  size_t erase(const key_type& x) {
    return Base::_erase(x);
  }
  iterator erase(iterator it) {
    return Base::_erase(it);
  }
  iterator erase(const_iterator it) {
    return Base::_erase(it);
  }
};

template<typename Base>
using SetTraits = SetTraitsBase<Base>;

/// Map flavour of SetTraitsBase.  `operator[]` returns the mapped value for
/// key `x`, first inserting a value-initialised `mapped_type` when `x` is
/// absent.
template<typename Base>
class MapTraits : public SetTraitsBase<Base> {
  using SBase = SetTraitsBase<Base>;
 public:
  using typename SBase::key_type;
  using typename SBase::mapped_type;
  using typename SBase::value_type;
  using reference = mapped_type&;
  MapTraits() = default;
  template<typename InputIt>
  explicit MapTraits(InputIt begin, InputIt end) : SBase(begin, end) {}
  MapTraits(std::initializer_list<value_type> init) : SBase(init) {}
  template<typename Key>
  reference operator[](Key&& x) {
    return _find_or_emplace_default(std::forward<Key>(x));
  }
  reference operator[](const key_type& x) {
    return _find_or_emplace_default(x);
  }
  reference operator[](key_type&& x) {
    return _find_or_emplace_default(std::move(x));
  }
 private:
  /// Shared body of the operator[] overloads.  The `lower_bound` result is
  /// both the lookup answer and the insertion hint, so a miss does not
  /// search again from scratch.
  template<typename Key>
  reference _find_or_emplace_default(Key&& x) {
    auto pos = SBase::lower_bound(x);
    const bool missing = pos == SBase::end() || x < pos->first;
    if (missing)
      pos = SBase::emplace_hint(pos, std::forward<Key>(x), mapped_type());
    return pos->second;
  }
};

} // namespace traits
#line 2 "include/mtl/bit_manip.hpp"
#include <cstdint>
#include <cassert>
#if __cplusplus >= 202002L
#ifndef MTL_CPP20
#define MTL_CPP20
#endif
#include <bit>
#endif

namespace bm {

/// Per-byte population count (SWAR): byte k of the result is the number of
/// set bits in byte k of `x`.
inline constexpr uint64_t popcnt_e8(uint64_t x) {
  x = (x & 0x5555555555555555) + ((x>>1) & 0x5555555555555555);
  x = (x & 0x3333333333333333) + ((x>>2) & 0x3333333333333333);
  x = (x & 0x0F0F0F0F0F0F0F0F) + ((x>>4) & 0x0F0F0F0F0F0F0F0F);
  return x;
}
/// Count 1s
inline constexpr unsigned popcnt(uint64_t x) {
#ifdef MTL_CPP20
  return std::popcount(x);
#else
  // Multiplying by 0x0101... sums all per-byte counts into the top byte.
  return (popcnt_e8(x) * 0x0101010101010101) >> 56;
#endif
}
/// Alias to bm::popcnt(x)
constexpr unsigned popcount(uint64_t x) {
  return popcnt(x);
}
/// Count trailing 0s (64 for x == 0). s.t. *11011000 -> 3
inline constexpr unsigned ctz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_zero(x);
#else
  // (x & -x) isolates the lowest set bit; subtracting 1 leaves a mask of
  // exactly the trailing zeros (all ones when x == 0).
  return popcnt((x & (-x)) - 1);
#endif
}
/// Alias to bm::ctz(x)
constexpr unsigned countr_zero(uint64_t x) {
  return ctz(x);
}
/// Count trailing 1s. s.t. *11011011 -> 2
inline constexpr unsigned cto(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_one(x);
#else
  return ctz(~x);
#endif
}
/// Alias to bm::cto(x)
constexpr unsigned countr_one(uint64_t x) {
  return cto(x);
}
/// Count trailing 0s of a byte (8 for x == 0).
inline constexpr unsigned ctz8(uint8_t x) {
  return x == 0 ? 8 : popcnt_e8((x & (-x)) - 1);
}
/// Byte-occupancy mask: bit k of the result is 1 iff byte k of `x` is
/// non-zero.  [00..0](8bit) -> 0, [**..*](not only 0) -> 1
inline constexpr uint8_t summary(uint64_t x) {
  constexpr uint64_t hmask = 0x8080808080808080ull;
  constexpr uint64_t lmask = 0x7F7F7F7F7F7F7F7Full;
  auto a = x & hmask;
  auto b = x & lmask;
  // 0x80 - low7 keeps the high bit set only when the low 7 bits are zero;
  // low7 <= 0x7F, so no borrow crosses byte boundaries.
  b = hmask - b;
  b = ~b;
  auto c = (a | b) & hmask;
  // Gather the eight per-byte flags (bits 7, 15, ..., 63) into the top byte.
  c *= 0x0002040810204081ull;
  return uint8_t(c >> 56);
}
/// Extract `len` bits of `x` starting at bit `start` (requires start < 64).
inline constexpr uint64_t bextr(uint64_t x, unsigned start, unsigned len) {
  uint64_t mask = len < 64 ? (1ull<<len)-1 : 0xFFFFFFFFFFFFFFFFull;
  return (x >> start) & mask;
}
/// Bit width of a byte: index of the highest set bit plus one, 0 for x == 0.
/// 00101101 -> 00111111 -count_1s-> 6
inline constexpr unsigned log2p1(uint8_t x) {
  if (x == 0)
    return 0;  // the SWAR below would report 1: the borrow chain clears byte 7's sign bit
  if (x & 0x80)
    return 8;
  // Replicate x into every byte, then subtract 2^k from byte k.  Byte k keeps
  // its high bit clear exactly when x >= 2^k, so counting those bytes gives
  // the bit width.
  uint64_t p = uint64_t(x) * 0x0101010101010101ull;
  p -= 0x8040201008040201ull;
  p = ~p & 0x8080808080808080ull;
  p = (p >> 7) * 0x0101010101010101ull;
  p >>= 56;
  return p;
}
/// Index of the most significant set bit of a non-zero byte.
/// 00101100 -mask_mssb-> 00100000 -to_index-> 5
inline constexpr unsigned mssb8(uint8_t x) {
  assert(x != 0);
  return log2p1(x) - 1;
}
/// Index of the least significant set bit of a non-zero byte.
/// 00101100 -mask_lssb-> 00000100 -to_index-> 2
inline constexpr unsigned lssb8(uint8_t x) {
  assert(x != 0);
  return popcnt_e8((x & -x) - 1);
}
/// Count leading 0s (64 for x == 0). 00001011... -> 4
inline constexpr unsigned clz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_zero(x);
#else
  if (x == 0)
    return 64;
  // Highest non-zero byte first, then the highest set bit inside it.
  auto i = mssb8(summary(x));
  auto j = mssb8(bextr(x, 8 * i, 8));
  return 63 - (8 * i + j);
#endif
}
/// Alias to bm::clz(x)
constexpr unsigned countl_zero(uint64_t x) {
  return clz(x);
}
/// Count leading 1s. 11110100... -> 4
inline constexpr unsigned clo(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_one(x);
#else
  return clz(~x);
#endif
}
/// Alias to bm::clo(x)
constexpr unsigned countl_one(uint64_t x) {
  return clo(x);
}

/// Count leading 0s of a byte (8 for x == 0).
inline constexpr unsigned clz8(uint8_t x) {
  return x == 0 ? 8 : 7 - mssb8(x);
}
/// Reverse the order of all 64 bits (bit 0 <-> bit 63, ...).
inline constexpr uint64_t bit_reverse(uint64_t x) {
  x = ((x & 0x00000000FFFFFFFF) << 32) | ((x & 0xFFFFFFFF00000000) >> 32);
  x = ((x & 0x0000FFFF0000FFFF) << 16) | ((x & 0xFFFF0000FFFF0000) >> 16);
  x = ((x & 0x00FF00FF00FF00FF) << 8) | ((x & 0xFF00FF00FF00FF00) >> 8);
  x = ((x & 0x0F0F0F0F0F0F0F0F) << 4) | ((x & 0xF0F0F0F0F0F0F0F0) >> 4);
  x = ((x & 0x3333333333333333) << 2) | ((x & 0xCCCCCCCCCCCCCCCC) >> 2);
  x = ((x & 0x5555555555555555) << 1) | ((x & 0xAAAAAAAAAAAAAAAA) >> 1);
  return x;
}

/// Check if x is power of 2. 00100000 -> true, 00100001 -> false
constexpr bool has_single_bit(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::has_single_bit(x);
#else
  return x != 0 && (x & (x - 1)) == 0;
#endif
}

/// Bit width needed to represent x (0 for x == 0). 00110110 -> 6
constexpr int bit_width(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::bit_width(x);
#else
  return 64 - clz(x);
#endif
}

/// Smallest power of 2 >= x (1 for x == 0; x must be <= 2^63).
/// 00110110 -> 01000000
constexpr uint64_t bit_ceil(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_ceil(x);
#else
  if (x == 0) return 1;
  return 1ull << bit_width(x - 1);
#endif
}

/// Largest power of 2 <= x (0 for x == 0). 00110110 -> 00100000
constexpr uint64_t bit_floor(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_floor(x);
#else
  if (x == 0) return 0;
  return 1ull << (bit_width(x) - 1);
#endif
}

} // namespace bm
#line 4 "include/mtl/binary_trie.hpp"
#include <array>
#include <memory>
#line 9 "include/mtl/binary_trie.hpp"
#include <algorithm>
#line 11 "include/mtl/binary_trie.hpp"

/// Binary trie over the low `W` bits of unsigned keys (`M == void` gives a
/// set of keys, otherwise a map from key to `M`).  Keys are expected to be
/// below 2^W.
///
/// Structure (as relied upon by `_lower_bound` and `_emplace_impl`):
///  - all leaves form a circular, key-ordered doubly-linked list closed by
///    the sentinel `dummy_`, which also serves as `end()`;
///  - an inner node with a missing child keeps in `jump` the smallest leaf of
///    its subtree (left child missing) or the largest one (right child
///    missing), so predecessor/successor queries need one root-to-node walk.
template<typename T, typename M,
    int8_t W = sizeof(T) * 8>
class BinaryTrieBase : public traits::AssociativeArrayDefinition<T, M> {
  static_assert(std::is_unsigned<T>::value, "");
 public:
  using types = traits::AssociativeArrayDefinition<T, M>;
  using key_type = typename types::key_type;
  using value_type = typename types::value_type;
  using init_type = typename types::init_type;
  struct Node;
  using node_ptr = std::shared_ptr<Node>;
  using node_weak_ptr = std::weak_ptr<Node>;
  struct Leaf;
  using leaf_ptr = std::shared_ptr<Leaf>;
  /// Inner trie node: `c[0]`/`c[1]` are the children for bit 0/1, `parent`
  /// is a non-owning back-link, and `jump` is the shortcut leaf described
  /// above (cleared once both children exist).
  struct Node {
    std::array<node_ptr, 2> c;
    leaf_ptr jump;
    node_weak_ptr parent;
    Node() = default;
    node_ptr& left() { return c[0]; }
    node_ptr& right()  { return c[1]; }
  };
  /// Leaf holding one element.  A leaf has no trie children, so it reuses
  /// `Node::c` as its list links: `c[0]` = prev, `c[1]` = next.
  struct Leaf : Node {
    value_type v;
    Leaf() = default;
    Leaf(const value_type& v) : Node(), v(v) {}
    Leaf(value_type&& v) : Node(), v(std::forward<value_type>(v)) {}
    key_type key() const {
      return types::key_of(v);
    }
    using Node::c;
    leaf_ptr prev() const {
      return std::static_pointer_cast<Leaf>(c[0]);
    }
    leaf_ptr next() const {
      return std::static_pointer_cast<Leaf>(c[1]);
    }
    void set_prev(leaf_ptr l) {
      c[0] = std::static_pointer_cast<Node>(l);
    }
    void set_next(leaf_ptr l) {
      c[1] = std::static_pointer_cast<Node>(l);
    }
  };
 protected:
  node_ptr root_;
  leaf_ptr dummy_;  // list sentinel; null after _deinit() and when moved-from
  size_t size_ = 0;
  /// Create an empty trie: a bare root whose `jump` is the self-linked
  /// sentinel.  Uses the virtual `create_node_at`; when reached from a
  /// constructor of this class, derived overrides are not called.
  virtual void _init() {
    root_ = create_node_at(0, 0);
    dummy_ = std::make_shared<Leaf>();
    root_->jump = dummy_;
    dummy_->set_next(dummy_);
    dummy_->set_prev(dummy_);
    size_ = 0;
  }
  /// Release every node.  The leaf list is a cycle of shared_ptrs, so it is
  /// unlinked explicitly; otherwise the leaves would never be freed.  A
  /// moved-from object (null `dummy_`) has nothing left to release.
  void _deinit() {
    root_ = nullptr;
    if (!dummy_)
      return;
    auto u = dummy_->next();
    dummy_->set_next(nullptr);
    u->set_prev(nullptr);
    while (u != dummy_) {
      auto n = u->next();
      u->set_next(nullptr);
      n->set_prev(nullptr);
      u = n;
    }
    dummy_ = nullptr;
  }
 public:
  BinaryTrieBase() {
    _init();
  }
  BinaryTrieBase(const BinaryTrieBase& rhs) {
    _insert_init(rhs.begin(), rhs.end());
  }
  virtual BinaryTrieBase& operator=(const BinaryTrieBase& rhs) {
    if (this == &rhs)
      return *this;  // _deinit() would otherwise destroy the source
    _deinit();
    _insert_init(rhs.begin(), rhs.end());
    return *this;
  }
  /// Steals all nodes.  `rhs` is left without nodes (size 0): it may be
  /// destroyed, cleared or assigned to, but not iterated until then.
  BinaryTrieBase(BinaryTrieBase&& rhs) noexcept
      : root_(std::move(rhs.root_)),
        dummy_(std::move(rhs.dummy_)),
        size_(rhs.size_) {
    rhs.size_ = 0;
  }
  virtual BinaryTrieBase& operator=(BinaryTrieBase&& rhs) noexcept {
    if (this == &rhs)
      return *this;
    _deinit();
    root_ = std::move(rhs.root_);
    dummy_ = std::move(rhs.dummy_);
    size_ = rhs.size_;
    rhs.size_ = 0;
    return *this;
  }
  virtual ~BinaryTrieBase() {
    _deinit();
  }
 protected:
  /// (Re)build the trie from [begin, end).  Sorted input is built in one
  /// linear pass; unsorted input falls back to repeated `_insert`.  The range
  /// is scanned twice (first by `std::is_sorted`), so it must be multi-pass.
  /// `_init()` overwrites `root_`/`dummy_`, so callers must start from a
  /// state without nodes (fresh object or after `_deinit()`).
  template<class InputIt>
  void _insert_init(InputIt begin, InputIt end) {
    static_assert(std::is_convertible<typename std::iterator_traits<InputIt>::value_type, value_type>::value, "");
    _init();
    if (begin == end) return;
    if (!std::is_sorted(begin, end, [](auto& l, auto& r) {
      return types::key_of(l) < types::key_of(r);
    })) {
      for (auto it = begin; it != end; ++it)
        _insert(*it);
      return;
    }
    // Append leaf `l` at the back of the list (just before the sentinel).
    auto push_link = [&](leaf_ptr l) {
      auto b = dummy_->prev();
      l->set_prev(b);
      l->set_next(dummy_);
      l->prev()->set_next(l);
      l->next()->set_prev(l);
    };
    // us[i] is the depth-i node on the path of the most recently added key.
    std::array<node_ptr, W> us{};
    // Create the path for `x` below depth k, ending in leaf `l`.
    auto grow = [&](key_type x, int k, leaf_ptr l) {
      for (int i = k; i < W-1; i++) {
        us[i+1] = create_node_at(x, i+1);
        int c = (x >> (W-i-1)) & 1;
        us[i]->c[c] = us[i+1];
        us[i+1]->parent = us[i];
        us[i+1]->jump = l;
      }
      int c = x & 1;
      us[W-1]->c[c] = l;
      l->parent = us[W-1];
    };
    us[0] = root_;
    key_type x = types::key_of(*begin);
    auto l = create_leaf_at(x, *begin);
    push_link(l);
    us[0]->jump = l;
    grow(x, 0, l);
    size_t sz = 1;
    for (auto it = std::next(begin); it != end; ++it) {
      key_type px = x;
      x = types::key_of(*it);
      auto m = x ^ px;
      if (m == 0) continue;  // duplicate key
//      [[assume(m != 0)]]
      // k = depth at which x diverges from the previous key.
      int k = W-1;
      while (m > 1) {
        m >>= 1;
        --k;
      }
      l = create_leaf_at(x, *it);
      push_link(l);
      // Ancestors still lacking a right child have a new maximum leaf.
      for (int i = 0; i < k; i++)
        if (!us[i]->c[1]) us[i]->jump = l;
      us[k]->jump = nullptr;  // us[k] now has both children
      grow(x, k, l);
      ++sz;
    }
    size_ = sz;
  }
 public:
  template<typename InputIt>
  explicit BinaryTrieBase(InputIt begin, InputIt end) {
    _insert_init(begin, end);
  }
  size_t size() const {
    return size_;
  }
  bool empty() const { return size() == 0; }
  void clear() {
    _deinit();
    _init();
  }
 protected:
  template<bool> struct iterator_base;
 public:
  using iterator = iterator_base<false>;
  using const_iterator = iterator_base<true>;
  iterator begin() {
    return iterator(dummy_->next());
  }
  iterator end() {
    return iterator(dummy_);
  }
  const_iterator begin() const {
    return const_iterator(dummy_->next());
  }
  const_iterator end() const {
    return const_iterator(dummy_);
  }
  /// Walk down from the root, choosing the child at each level with
  /// `rule(bit_index, has_left, has_right)` (bit_index runs from W-1 down to
  /// 0), and return the leaf reached.  The rule must select an existing
  /// child.  Returns end() for an empty trie (the root has no children).
  template<class Rule>
  const_iterator traverse(Rule rule) const {
    if (empty())
      return end();
    auto u = root_;
    for (int i = 0; i < W; i++) {
      auto l = (bool)u->c[0];
      auto r = (bool)u->c[1];
      auto c = rule(W-1-i, l, r);
      u = u->c[c];
    }
    return const_iterator(std::static_pointer_cast<Leaf>(u));
  }
 protected:
  /// Descend from `root` (default `root_`), which sits at depth `depth`,
  /// following the bits of `key` while the child exists.  Returns the depth
  /// reached (W means the leaf for `key` exists) and the node there.
  virtual std::pair<int, node_ptr> _traverse(const key_type& key, 
                                             int depth = 0, 
                                             node_ptr root = nullptr) const {
    int i, c;
    key_type x = key;
    auto u = !root ? root_ : root;
    for (i = depth; i < W; i++) {
      c = (x >> (W-i-1)) & 1;
      if (!u->c[c]) break;
      u = u->c[c];
    }
    return std::make_pair(i, u);
  }
  /// First element with key >= x, via the `jump` of the deepest node on x's
  /// path.
  iterator _lower_bound(const key_type& x) const {
    auto reached = _traverse(x);
    int i = reached.first;
    node_ptr u = reached.second;
    if (i == W) return iterator(std::static_pointer_cast<Leaf>(u));
    auto l = (((x >> (W-i-1)) & 1) == 0) ? u->jump : u->jump->next();
    return iterator(l);
  }
  /// First element with key > x.  end() is the sentinel, whose stored value
  /// is not an element, so it is never inspected.
  iterator _upper_bound(const key_type& x) const {
    auto it = _lower_bound(x);
    if (it != end() && types::key_of(*it) == x)
      ++it;
    return it;
  }
  virtual iterator _find(const key_type& x) const {
    auto reached = _traverse(x);
    int i = reached.first;
    node_ptr u = reached.second;
    if (i == W)
      return iterator(std::static_pointer_cast<Leaf>(u));
    else
      return end();
  }
  /// Node factories.  Overridden by derived tries (e.g. XFastTrieBase) to
  /// index each new node.
  virtual node_ptr create_node_at(const key_type&, int) {
    return std::make_shared<Node>();
  }
  virtual leaf_ptr create_leaf_at(const key_type&, const init_type& value) {
    return std::make_shared<Leaf>(value);
  }
  virtual leaf_ptr create_leaf_at(const key_type&, init_type&& value) {
    return std::make_shared<Leaf>(std::move(value));
  }
  /// Insert key `x` (absent) below `forked`, the deepest existing node on its
  /// path, at depth `height`.  Splices the new leaf into the list next to the
  /// neighbour found through `forked->jump`, creates the missing path, then
  /// repairs `jump` on the ancestors.
  template<typename Value>
  iterator _emplace_impl(key_type x, int height, node_ptr forked, Value&& value) {
    assert(height < W);
    int i = height;
    node_ptr u = forked;
    auto f = u;
    int c = (x >> (W-i-1)) & 1;
    auto fc = c;
    auto fi = i;
    auto pred = c == 1 ? u->jump : u->jump->prev();
    u->jump = nullptr;
    auto l = create_leaf_at(x, std::forward<Value>(value));
    l->set_prev(pred);
    l->set_next(pred->next());
    l->prev()->set_next(l);
    l->next()->set_prev(l);
    for (; i < W-1; i++) {
      c = (x >> (W-i-1)) & 1;
      assert(!u->c[c]);
      u->c[c] = create_node_at(x, i+1);
      u->c[c]->parent = u;
      u->c[c]->jump = l;
      u = u->c[c];
    }
    {
      c = (x >> (W-i-1)) & 1;
      u->c[c] = l;
      u->c[c]->parent = u;
    }
    if (f == root_) [[unlikely]] {
      f->jump = l;
    } else [[likely]] {
      auto v = f->parent.lock();
      fi--;
      while (v) {
        c = x >> (W-fi-1) & 1;
        if (c != fc and !v->jump)
          break;
        if (!v->c[fc])
          v->jump = l;
        v = v->parent.lock();
        fi--;
      }
    }
    size_++;
    return iterator(l);
  }
  template<typename Value>
  std::pair<iterator, bool> _insert(Value&& value) {
    static_assert(std::is_convertible<Value, value_type>::value, "");
    key_type x = types::key_of(value);
    auto reached = _traverse(x);
    int i = reached.first;
    node_ptr u = reached.second;
    if (i == W)
      return std::make_pair(iterator(std::static_pointer_cast<Leaf>(u)), false);
    return std::make_pair(_emplace_impl(x, i, u, std::forward<Value>(value)), true);
  }
  /// Depth and node of the lowest common ancestor of leaf `l` and key `x`
  /// (W and `l` itself when the keys are equal), found by walking up.
  virtual std::pair<int, node_ptr> climb_to_lca(leaf_ptr l, key_type x) {
    key_type m = x ^ types::key_of(l->v);
    if (m == 0)
      return std::make_pair(W, std::static_pointer_cast<Node>(l));
    int h = bm::clz(m) - (64 - W);
    node_ptr f = std::static_pointer_cast<Node>(l);
    for (int i = W; i > h; i--)
      f = f->parent.lock();
    return std::make_pair(h, f);
  }
  /// Insert near `hint`: start the descent from the LCA of the hint leaf and
  /// the key instead of from the root.
  template<class Value>
  iterator _emplace_hint(const_iterator hint, Value&& value) {
    key_type x = types::key_of(value);
    if (empty())
      return _emplace_impl(x, 0, root_, std::forward<Value>(value));
    if (hint == end())
      --hint;
    int h;
    node_ptr f;
    std::tie(h, f) = climb_to_lca(hint.ptr_, x);
    std::tie(h, f) = _traverse(x, h, f);
    if (h == W)
      return iterator(std::static_pointer_cast<Leaf>(f));
    return _emplace_impl(x, h, f, std::forward<Value>(value));
  }

  /// Hook called for every node removed from the trie (depth `i`).
  virtual void erase_node_at(const key_type&, int, node_ptr) {}
  virtual bool _erase(const key_type& key) {
    auto it = _find(key);
    if (it != end()) {
      _erase_from_leaf(types::key_of(*it), it.ptr_);
      return true;
    } else {
      return false;
    }
  }
  /// Unlink leaf `l` (whose key is `key`), prune the ancestors left without
  /// children, fix up `jump` pointers, and return the following element.
  template<typename Key>
  iterator _erase_from_leaf(Key&& key, leaf_ptr l) {
    static_assert(std::is_convertible<Key, key_type>::value, "");
    key_type x = std::forward<Key>(key);
    assert(x == l->key());
    l->prev()->set_next(l->next());
    l->next()->set_prev(l->prev());
    int i,c;
    auto v = std::static_pointer_cast<Node>(l);
    for (i = W-1; i >= 0; i--) {
      erase_node_at(x, i+1, v);
      v = v->parent.lock();
      c = (x >> (W-i-1)) & 1;
      v->c[c] = nullptr;
      if (v->c[c^1]) break;
    }
    auto nj = c ? l->prev() : l->next();
    auto fc = c;
    v->jump = nj;
    v = v->parent.lock();
    i--;
    for (; i >= 0; i--) {
      assert(v);
      c = (x >> (W-i-1)) & 1;
      if (c != fc) {
        if (!v->jump) break;
        v->jump = nj;
      }
      v = v->parent.lock();
    }
    size_--;
    return iterator(l->next());
  }
  iterator iterator_remove_const(const const_iterator& it) {
    return iterator(it.ptr_);
  }
  iterator iterator_remove_const(const_iterator&& it) {
    return iterator(std::move(it.ptr_));
  }
  iterator _erase(iterator it) {
    if (it == end()) return it;
    return _erase_from_leaf(types::key_of(*it), it.ptr_);
  }
  iterator _erase(const_iterator it) {
    if (it == end()) return iterator_remove_const(it);
    return _erase_from_leaf(types::key_of(*it), it.ptr_);
  }
 protected:
  /// Bidirectional iterator over the leaf list; `Const` selects the
  /// constness of the referenced element.
  template<bool Const>
  struct iterator_base {
    using difference_type = ptrdiff_t;
    using value_type = BinaryTrieBase::value_type;
    using pointer = typename std::conditional<Const,
                                              const value_type*,
                                              value_type*>::type;
    using reference = typename std::conditional<Const,
                                                const value_type&,
                                                value_type&>::type;
    using iterator_category = std::bidirectional_iterator_tag;
    leaf_ptr ptr_;
    iterator_base() = default;
    iterator_base(leaf_ptr p) : ptr_(p) {}
    template<bool C>
    iterator_base(const iterator_base<C>& rhs) : ptr_(rhs.ptr_) {}
    template<bool C>
    iterator_base& operator=(const iterator_base<C>& rhs) {
      ptr_ = rhs.ptr_;
      return *this;
    }
    template<bool C>
    iterator_base(iterator_base<C>&& rhs) : ptr_(std::move(rhs.ptr_)) {}
    template<bool C>
    iterator_base& operator=(iterator_base<C>&& rhs) {
      ptr_ = std::move(rhs.ptr_);
      return *this;
    }
    reference operator*() const {
      return ptr_->v;
    }
    pointer operator->() const {
      return &(ptr_->v);
    }
    template<bool C>
    bool operator==(const iterator_base<C>& rhs) const {
      return ptr_ == rhs.ptr_;
    }
    template<bool C>
    bool operator!=(const iterator_base<C>& rhs) const {
      return !operator==(rhs);
    }
    iterator_base& operator++() {
      ptr_ = ptr_->next();
      return *this;
    }
    iterator_base operator++(int) {
      iterator_base ret = *this;
      operator++();
      return ret;
    }
    iterator_base& operator--() {
      ptr_ = ptr_->prev();
      return *this;
    }
    iterator_base operator--(int) {
      iterator_base ret = *this;
      operator--();
      return ret;
    }
  };
};

/// Map from unsigned `Key` to `Value` over `Bits`-bit keys (keys < 2^Bits).
template<class Key, class Value, uint8_t Bits = sizeof(Key)*8>
using BinaryTrie = traits::MapTraits<BinaryTrieBase<Key, Value, Bits>>;
/// Set of unsigned `Key` over `Bits`-bit keys (keys < 2^Bits).
template<class Key, uint8_t Bits = sizeof(Key)*8>
using BinaryTrieSet = traits::SetTraits<BinaryTrieBase<Key, void, Bits>>;
/// Alternative name for BinaryTrie.
template<class Key, class Value, uint8_t Bits = sizeof(Key)*8>
using BinaryTrieMap = BinaryTrie<Key, Value, Bits>;
#line 7 "include/mtl/xft.hpp"
#include <unordered_map>
#line 9 "include/mtl/xft.hpp"

/// Value type of every per-depth hash table: a pointer to the trie node
/// reached by a key prefix.
template<class Key, class Mapped, int8_t Bits>
using XFastTrieHashTableMappedType =
    typename BinaryTrieBase<Key, Mapped, Bits>::node_ptr;
/// Hash-table template used for each depth unless the user supplies one.
#define XFT_DEFAULT_HASH_TABLE        std::unordered_map
/// Spells the per-depth table type for key T, mapped M, width W with
/// hash-table template HT.
#define XFT_HASH_TABLE_TYPE(HT,T,M,W) HT<T, XFastTrieHashTableMappedType<T, M, W>>

/// X-fast trie: a BinaryTrieBase that also indexes every node by
/// (depth, key prefix), with one hash table per depth (`tb_[0..W]`, leaves at
/// depth W).  `_find` is a single lookup in `tb_[W]`.  `_traverse` finds
/// the deepest existing prefix of a key by binary search over the depths,
/// i.e. O(log W) hash lookups instead of an O(W) walk.
template<typename T, typename M,
    int8_t W = sizeof(T) * 8,
    class HashTable = XFT_HASH_TABLE_TYPE(XFT_DEFAULT_HASH_TABLE, T, M, W)>
class XFastTrieBase : public BinaryTrieBase<T, M, W> {
  static_assert(std::is_unsigned<T>::value, "");
  using Base = BinaryTrieBase<T, M, W>;
 public:
  using hash_table_type = HashTable;
  using types = typename Base::types;
  using value_type = typename types::value_type;
  using init_type = typename types::init_type;
  using typename Base::Node;
  using typename Base::Leaf;
  using typename Base::node_ptr;
  using typename Base::leaf_ptr;
  using typename Base::key_type;
 protected:
  using Base::root_;
  using Base::dummy_;
  using Base::size_;
  /// tb_[i] maps the top-i-bit prefix of a key to the depth-i node on its
  /// path (so tb_[W] maps full keys to leaves).
  std::array<hash_table_type, W+1> tb_;
  /// Hash key of the depth-`i` ancestor of `x`: its top `i` of W bits.  The
  /// guard avoids shifting by the full width of key_type (undefined), which
  /// would happen for i == 0 when W equals that width.
  static key_type _prefix(const int i, const key_type& x) {
    return W-i < (int)sizeof(key_type)*8 ? key_type(x >> (W-i)) : key_type(0);
  }
  void _store_node(const int i, const key_type& x, node_ptr u) {
    tb_[i].emplace(_prefix(i, x), u);
  }
  void _init() override {
    for (auto& t:tb_) t.clear();
    Base::_init();
  }
  /// Exchange the complete state (nodes, leaf list, size and tables).
  void _swap_contents(XFastTrieBase& rhs) noexcept {
    using std::swap;
    swap(root_, rhs.root_);
    swap(dummy_, rhs.dummy_);
    swap(size_, rhs.size_);
    swap(tb_, rhs.tb_);
  }
 public:
  /// Base() runs before this object is an XFastTrieBase, so its virtual
  /// `_init`/`create_node_at` calls reach the BinaryTrieBase versions and the
  /// root is not recorded in `tb_`.  Record it here so the state matches
  /// what `_init()` produces.
  XFastTrieBase() : Base() {
    _store_node(0, 0, root_);
  }
  XFastTrieBase(const XFastTrieBase& rhs) : XFastTrieBase() {
    Base::operator=(rhs);
  }
  XFastTrieBase& operator=(const XFastTrieBase& rhs) {
    if (this != &rhs)
      Base::operator=(rhs);
    return *this;
  }
  /// Moves swap state with `rhs`, so the tables travel together with the
  /// nodes they index and `rhs` stays a valid trie.
  XFastTrieBase(XFastTrieBase&& rhs) noexcept : XFastTrieBase() {
    _swap_contents(rhs);
  }
  XFastTrieBase& operator=(XFastTrieBase&& rhs) noexcept {
    if (this != &rhs)
      _swap_contents(rhs);
    return *this;
  }
  template<typename InputIt>
  explicit XFastTrieBase(InputIt begin, InputIt end) : XFastTrieBase() {
    // Release the empty trie built above; `_insert_init` re-initialises and
    // would otherwise leak the old self-linked sentinel.
    Base::_deinit();
    Base::_insert_init(begin, end);
  }
  using iterator = typename Base::iterator;
  using Base::end;
 protected:
  node_ptr create_node_at(const key_type& x, int i) override {
    auto u = Base::create_node_at(x, i);
    _store_node(i, x, u);
    return u;
  }
  leaf_ptr create_leaf_at(const key_type& x, const init_type& value) override {
    auto l = Base::create_leaf_at(x, value);
    _store_node(W, x, std::static_pointer_cast<Node>(l));
    return l;
  }
  leaf_ptr create_leaf_at(const key_type& x, init_type&& value) override {
    auto l = Base::create_leaf_at(x, std::move(value));
    _store_node(W, x, std::static_pointer_cast<Node>(l));
    return l;
  }
  void erase_node_at(const key_type& x, int i, node_ptr u) override {
    Base::erase_node_at(x, i, u);
    auto it = tb_[i].find(_prefix(i, x));
    assert(it != tb_[i].end());
    assert(it->second == u);
    tb_[i].erase(it);
  }
  /// Binary search over depths (depth, W] for the longest prefix of `key`
  /// present in the tables.  Prefixes are monotone: if depth i exists, so do
  /// all shallower depths.  Falls back to `root` (default `root_`) at
  /// `depth` when none is found.
  std::pair<int, node_ptr> _traverse(const key_type& key, 
                                     int depth = 0, 
                                     node_ptr root = nullptr) const override {
    key_type x = key;
    int l = depth, h = W+1;
    node_ptr u = !root ? root_ : root;
    while (l+1 < h) {
      int i = l+(h-l)/2;
      auto it = tb_[i].find(_prefix(i, x));
      if (it != tb_[i].end()) {
        l = i;
        u = it->second;
      } else {
        h = i;
      }
    }
    return std::make_pair(l, u);
  }
  iterator _find(const key_type& x) const override {
    auto it = tb_[W].find(x);
    if (it != tb_[W].end())
      return iterator(std::static_pointer_cast<Leaf>(it->second));
    else
      return end();
  }
  using Base::_insert;
  /// LCA of leaf `l` and key `x`, looked up directly in the depth table
  /// instead of walking parent links.
  std::pair<int, node_ptr> climb_to_lca(leaf_ptr l, key_type x) override {
    key_type m = x ^ types::key_of(l->v);
    if (m == 0)
      return std::make_pair(W, std::static_pointer_cast<Node>(l));
    int h = bm::clz(m) - (64 - W);
    if (h > 0) {
      // Read-only lookup: operator[] would insert a null entry on a miss.
      auto it = tb_[h].find(_prefix(h, x));
      assert(it != tb_[h].end());
      if (it != tb_[h].end())
        return std::make_pair(h, it->second);
    }
    // Depth 0 is the root.  The root is also a correct (if slower) starting
    // point for the subsequent descent should the lookup above ever miss.
    return std::make_pair(0, root_);
  }
  using Base::_emplace_hint;
  using Base::_erase;
  bool _erase(const key_type& key) override {
    auto it = tb_[W].find(key);
    if (it != tb_[W].end()) {
      Base::_erase_from_leaf(key, std::static_pointer_cast<Leaf>(it->second));
      return true;
    } else {
      return false;
    }
  }
};

#line 137 "include/mtl/xft.hpp"

/// X-fast trie map from unsigned `Key` (`Bits`-bit keys) to `Value`, using
/// hash-table type `Table` for each depth.
template<class Key, class Value, uint8_t Bits = sizeof(Key)*8,
    class Table = XFT_HASH_TABLE_TYPE(XFT_DEFAULT_HASH_TABLE, Key, Value, Bits)>
using XFastTrie = traits::MapTraits<XFastTrieBase<Key, Value, Bits, Table>>;
/// X-fast trie set of unsigned `Key` (`Bits`-bit keys).
template<class Key, uint8_t Bits = sizeof(Key)*8,
    class Table = XFT_HASH_TABLE_TYPE(XFT_DEFAULT_HASH_TABLE, Key, void, Bits)>
using XFastTrieSet = traits::SetTraits<XFastTrieBase<Key, void, Bits, Table>>;
/// Alternative name for XFastTrie.
template<class Key, class Value, uint8_t Bits = sizeof(Key)*8,
    class Table = XFT_HASH_TABLE_TYPE(XFT_DEFAULT_HASH_TABLE, Key, Value, Bits)>
using XFastTrieMap = XFastTrie<Key, Value, Bits, Table>;
#line 1 "test/standalone/set_test.hpp"
#include <iostream>
#include <vector>
#line 6 "test/standalone/set_test.hpp"
#include <set>

namespace mtl {

using std::cout;
using std::cerr;
using std::endl;

/// Smoke test: `emplace(pair)` and `emplace(key, mapped)` must both compile
/// and run on `Map`.  Both calls use the default-constructed key, so for a
/// unique-key map the second call finds it already present.
template<class Map>
void map_emplace_test() {
  using K = typename Map::key_type;
  using V = typename Map::mapped_type;
  Map container;
  container.emplace(std::make_pair(K(), V()));
  container.emplace(K(), V());
}

/// Randomised consistency check of an integer ordered-set type `Set`.
///
/// Builds `Set` from a random subset of [0, Max) (each value kept when
/// rand() % 4 == 0), in shuffled order if `Shuffle`, else ascending.  It then
/// checks:
///  - iteration order;
///  - find / upper_bound / lower_bound for every i in [0, Max) against the
///    sorted reference vector;
///  - after erasing every element (in insertion order), the two-element
///    neighbourhood around the erased key, and the size.
/// Prints a diagnostic and calls exit(EXIT_FAILURE) on the first mismatch.
template<class Set, int Max = (int)4e5, bool Shuffle = true>
void integer_set_test() {
  std::vector<int> values;
  while (values.empty()) {
    for (int i = 0; i < Max; i++)
      if (std::rand() % 4 == 0)
        values.push_back(i);
  }
  const int n = (int)values.size();
  auto insertions = values;
  if constexpr (Shuffle) {
    // std::random_shuffle was deprecated in C++14 and removed in C++17.  Use
    // an explicit Fisher-Yates pass driven by the same rand() source.
    for (std::size_t i = 1; i < insertions.size(); i++) {
      const std::size_t j = static_cast<std::size_t>(std::rand()) % (i + 1);
      if (i != j)
        std::swap(insertions[i], insertions[j]);
    }
  }

  Set S(insertions.begin(), insertions.end());

  if (values != std::vector<int>(S.begin(), S.end())) {
    std::cout << "after insert order broken" << std::endl;
    std::exit(EXIT_FAILURE);
  }

  // Query phase: `pred`/`succ` track the reference neighbours of i, and
  // `target` is i itself when i is a member (else -1).
  int target = -1;
  int pred = -1;
  int succ = values[0];
  int k = 0;
  auto log = [&]() {
    std::cout << pred << ' ' << target << ' ' << succ << std::endl;
  };
  for (int i = 0; i < Max; i++) {
    if (k < n and values[k] == i) {
      target = values[k];
      pred = k-1 >= 0 ? values[k-1] : -1;
      succ = k+1 < n ? values[k+1] : -1;
      k++;
    } else {
      pred = k-1 >= 0 ? values[k-1] : -1;
      target = -1;
    }

    auto fit = S.find(i);
    if (fit != S.end()) {
      if ((int)*fit != i) {
        std::cout << "find: " << i << std::endl;
        log();
        std::exit(EXIT_FAILURE);
      }
    } else {
      if (target != -1) {
        log();
        std::exit(EXIT_FAILURE);
      }
    }

    auto sucit = S.upper_bound(i);
    if (sucit != S.end()) {
      if ((int)*sucit != succ) {
        std::cout << "succ: " << *sucit << std::endl;
        log();
        std::exit(EXIT_FAILURE);
      }
    } else {
      if (succ != -1) {
        log();
        std::exit(EXIT_FAILURE);
      }
    }

    auto predit = S.lower_bound(i);
    if (predit != S.begin()) {
      --predit;
      if ((int)*predit != pred) {
        std::cout << "pred: " << *predit << std::endl;
        log();
        std::exit(EXIT_FAILURE);
      }
    } else {
      if (pred != -1) {
        log();
        std::exit(EXIT_FAILURE);
      }
    }
  }

  int size = n;
  if ((int) S.size() != size) {
    std::cout << S.size() << ' ' << size<< std::endl;
    log();
    std::exit(EXIT_FAILURE);
  }

  // Erase phase: after each erase, the up-to-two neighbours on either side
  // of the removed key must be exactly the old neighbourhood minus that key.
  for (int v : insertions) {
    auto f = S.find(v);
    if (f == S.end()) {  // explicit check: an assert vanishes under NDEBUG
      std::cout << "find before erase: " << v << std::endl;
      std::exit(EXIT_FAILURE);
    }
    auto p = f;
    auto m = std::next(f);
    for (int i = 0; i < 2 and p != S.begin(); i++)
      --p;
    for (int i = 0; i < 2 and m != S.end(); i++)
      ++m;
    std::vector<unsigned> o(p,m);
    o.erase(std::find(o.begin(), o.end(), static_cast<unsigned>(v)));
    S.erase(v);
    size--;

    {
      auto lb = S.lower_bound(v);
      auto lo = lb, hi = lb;
      for (int i = 0; i < 2 and lo != S.begin(); i++)
        --lo;
      for (int i = 0; i < 2 and hi != S.end(); i++)
        ++hi;
      if (o != std::vector<unsigned>(lo,hi)) {
        std::cout << n-size<<" after erase "<<v<<" order broken " << std::endl;
        for (auto w:o)
          std::cerr<<w<<' ';
        std::cerr<<std::endl;
        for (auto it = lo; it != hi; ++it) {
          std::cerr<<*it<<' ';
        }
        std::cerr<<std::endl;
        std::exit(EXIT_FAILURE);
      }
    }
    if ((int) S.size() != size) {
      std::cout << S.size() << ' ' << size<< std::endl;
      std::exit(EXIT_FAILURE);
    }
  }
  std::cerr<<"integer_set_test ok"<<std::endl;
}

}
#line 4 "test/standalone/xft_test.cpp"

int main() {
  mtl::integer_set_test<XFastTrieSet<unsigned, 20>, 1<<20>();
  mtl::integer_set_test<XFastTrieSet<unsigned, 20>, 1<<20, false>();
  mtl::map_emplace_test<XFastTrieMap<unsigned, std::vector<int>>>();
  std::cout << "OK" << std::endl;
}
Back to top page