matsutaku-library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub MatsuTaku/matsutaku-library

:heavy_check_mark: test/standalone/binary_trie_test.cpp

Code

#define STANDALONE
#include "include/mtl/binary_trie.hpp"
#include <iostream>
#include <vector>

void test_constructor() {
    BinaryTrie<uint32_t, uint32_t> trie;
    assert(trie.empty());
    assert(trie.size() == 0);
}

void test_insert() {
    BinaryTrie<uint32_t, uint32_t> trie;
    auto result1 = trie.insert(std::make_pair(5, 10));
    auto result2 = trie.insert(std::make_pair(5, 20));
    assert(trie.size() == 1);
    assert(result1.second);
    assert(!result2.second);
    assert(result1.first->second == 10);
}

void test_find() {
    BinaryTrie<uint32_t, uint32_t> trie;
    trie.insert(std::make_pair(5, 10));
    trie.insert(std::make_pair(3, 20));

    auto result1 = trie.find(5);
    auto result2 = trie.find(3);
    auto result3 = trie.find(7);

    assert(result1 != trie.end());
    assert(result1->second == 10);
    assert(result2 != trie.end());
    assert(result2->second == 20);
    assert(result3 == trie.end());
}

void test_erase() {
    BinaryTrie<uint32_t, uint32_t> trie;
    trie.insert(std::make_pair(5, 10));
    trie.insert(std::make_pair(3, 20));

    bool result1 = trie.erase(5);
    bool result2 = trie.erase(7);

    assert(result1);
    assert(!result2);
    assert(trie.size() == 1);
    assert(trie.find(5) == trie.end());
    assert(trie.find(3) != trie.end());
}

void test_range_constructor() {
    std::vector<std::pair<uint32_t, uint32_t>> values = {
        {1, 10},
        {3, 30},
        {5, 50},
    };

    BinaryTrie<uint32_t, uint32_t> trie(values.begin(), values.end());
    assert(trie.size() == 3);
    assert(trie.find(1) != trie.end());
    assert(trie.find(3) != trie.end());
    assert(trie.find(5) != trie.end());
}

int main() {
    test_constructor();
    test_insert();
    test_find();
    test_erase();
    test_range_constructor();

    std::cout << "All tests passed!" << std::endl;
    return 0;
}
#line 1 "test/standalone/binary_trie_test.cpp"
#define STANDALONE
#line 2 "include/mtl/traits/set_traits.hpp"
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <type_traits>
#include <utility>
#if __cplusplus >= 202002L
#include <concepts>
#endif

namespace traits {

/// Type vocabulary for an associative container mapping keys `T` to values `M`.
///
/// `value_type` is the pair with a const key; `init_type` is the same pair
/// with cv-stripped members.
template<typename T, typename M>
struct AssociativeArrayDefinition {
  using key_type = T;
  using mapped_type = M;
  using value_type = std::pair<T const, M>;
  using raw_key_type = typename std::remove_const<T>::type;
  using raw_mapped_type = typename std::remove_const<M>::type;
  using init_type = std::pair<raw_key_type, raw_mapped_type>;
  using moved_type = std::pair<raw_key_type&&, raw_mapped_type&&>;
  /// Return type of `key_of` for a pair whose first member has type `K`.
  /// When `K` is the key type (ignoring const/reference) the key is returned
  /// by reference; otherwise it is converted and returned by value, because a
  /// reference would bind to a temporary destroyed on return.
  template<class K>
  using key_of_result = typename std::conditional<
      std::is_same<
          typename std::remove_const<typename std::remove_reference<K>::type>::type,
          raw_key_type>::value,
      key_type const&,
      key_type>::type;
  /// Extracts the key (first member) of the pair `kv`.
  template<class K, class V>
  static key_of_result<K> key_of(std::pair<K,V> const& kv) {
    return kv.first;
  }
};
/// Set flavour (`M = void`): the stored value is the key itself.
template<typename T>
struct AssociativeArrayDefinition<T, void> {
  using key_type = T;
  using value_type = T;
  using init_type = T;
  /// Returns `k` unchanged; for a set the value is its own key.
  static key_type const& key_of(value_type const& k) { return k; }
};

/// Selects the const iterator type of container `T`.
/// Primary template: `T` has no `const_iterator`, so a thin type derived from
/// `T::iterator` (constructible from it) is used instead.
template<class T, typename = std::void_t<>>
struct get_const_iterator {
  using base = typename T::iterator;
  struct type : base {
    type(const base& r) : base(r) {}
    type(base&& r) : base(std::move(r)) {}
  };
};
/// Specialisation chosen when `T::const_iterator` exists: use it directly.
template<class T>
struct get_const_iterator<T, std::void_t<typename T::const_iterator>> {
  using type = typename T::const_iterator;
};

#if __cplusplus >= 202002L
// <concepts> is included at file scope: a standard header must not be
// included from inside a namespace.
/// Minimal interface `SetTraitsBase` expects from its implementation class.
template<class M>
concept IsAssociativeArray = requires (M m) {
  typename M::key_type;
  typename M::value_type;
  typename M::iterator;
  {m.size()} -> std::convertible_to<size_t>;
  {m.empty()} -> std::same_as<bool>;
  {m.clear()};
  {m.begin()} -> std::same_as<typename M::iterator>;
  {m.end()} -> std::same_as<typename M::iterator>;
};
#endif

/// Adds a std::set-like public interface on top of an implementation `Base`.
///
/// Every operation forwards to an underscore-prefixed primitive of `Base`
/// (`_find`, `_lower_bound`, `_upper_bound`, `_insert`, `_emplace_hint`,
/// `_erase`), or to `Base::begin()` / `Base::end()`.
template<class Base>
#if __cplusplus >= 202002L
requires IsAssociativeArray<Base>
#endif
class SetTraitsBase : public Base {
 public:
  using key_type = typename Base::key_type;
  using value_type = typename Base::value_type;
  using init_type = typename Base::init_type;
  using iterator = typename Base::iterator;
  SetTraitsBase() = default;
  /// Builds the container from the range `[begin, end)`.
  template<typename InputIt>
  explicit SetTraitsBase(InputIt begin, InputIt end) : Base(begin, end) {
    static_assert(std::is_convertible<typename std::iterator_traits<InputIt>::value_type, value_type>::value, "");
  }
  /// Builds the container from a braced list of values.
  SetTraitsBase(std::initializer_list<value_type> init) : Base(init.begin(), init.end()) {}
  using Base::size;
  bool empty() const { return size() == 0; }
  using Base::clear;
  using const_iterator = typename get_const_iterator<Base>::type;
  iterator begin() {
    return Base::begin();
  }
  iterator end() {
    return Base::end();
  }
  const_iterator begin() const {
    return const_iterator(Base::begin());
  }
  const_iterator end() const {
    return const_iterator(Base::end());
  }
  const_iterator cbegin() const {
    return begin();
  }
  const_iterator cend() const {
    return end();
  }
  using reverse_iterator = std::reverse_iterator<iterator>;
  using reverse_const_iterator = std::reverse_iterator<const_iterator>;
  /// Standard-library spelling of `reverse_const_iterator`, so generic code
  /// written against std containers can name the type.
  using const_reverse_iterator = reverse_const_iterator;
  reverse_iterator rbegin() {
    return std::make_reverse_iterator(end());
  }
  reverse_iterator rend() {
    return std::make_reverse_iterator(begin());
  }
  reverse_const_iterator rbegin() const {
    return std::make_reverse_iterator(end());
  }
  reverse_const_iterator rend() const {
    return std::make_reverse_iterator(begin());
  }
  reverse_const_iterator crbegin() const {
    return rbegin();
  }
  reverse_const_iterator crend() const {
    return rend();
  }
  /// Forwards to `Base::_lower_bound`.
  template<class Key>
  iterator lower_bound(const Key& x) const {
    return Base::_lower_bound(x);
  }
  iterator lower_bound(const key_type& x) const {
    return Base::_lower_bound(x);
  }
  /// Forwards to `Base::_upper_bound`.
  template<class Key>
  iterator upper_bound(const Key& x) const {
    return Base::_upper_bound(x);
  }
  iterator upper_bound(const key_type& x) const {
    return Base::_upper_bound(x);
  }
  /// Forwards to `Base::_find`.
  template<class Key>
  iterator find(const Key& x) {
    return Base::_find(x);
  }
  iterator find(const key_type& x) {
    return Base::_find(x);
  }
  template<class Key>
  const_iterator find(const Key& x) const {
    return Base::_find(x);
  }
  const_iterator find(const key_type& x) const {
    return Base::_find(x);
  }
  /// 1 if `find(x)` locates an element, 0 otherwise.
  template<class Key>
  size_t count(const Key& x) const {
    return find(x) != end();
  }
  size_t count(const key_type& x) const {
    return find(x) != end();
  }
  /// Forwards to `Base::_insert`; the pair carries that call's iterator and
  /// insertion flag.
  std::pair<iterator, bool> insert(const init_type& v) {
    return Base::_insert(v);
  }
  std::pair<iterator, bool> insert(init_type&& v) {
    return Base::_insert(std::move(v));
  }
  // The value_type overloads are templates so that, when init_type and
  // value_type are the same type, the non-template overloads above win
  // overload resolution instead of being redeclarations.
  template<typename=void>
  std::pair<iterator, bool> insert(const value_type& v) {
    return Base::_insert(v);
  }
  template<typename=void>
  std::pair<iterator, bool> insert(value_type&& v) {
    return Base::_insert(std::move(v));
  }
  /// Constructs an `init_type` from `args` when possible (else a `value_type`)
  /// and inserts it.
  template<class... Args>
  std::pair<iterator, bool> emplace(Args&&... args) {
    using emplace_type = typename std::conditional<
        std::is_constructible<init_type, Args...>::value,
        init_type,
        value_type
    >::type;
    return Base::_insert(emplace_type(std::forward<Args>(args)...));
  }
  /// Like `emplace`, passing `hint` on to `Base::_emplace_hint`.
  template<class... Args>
  iterator emplace_hint(const_iterator hint, Args&&... args) {
    using emplace_type = typename std::conditional<
        std::is_constructible<init_type, Args...>::value,
        init_type,
        value_type
    >::type;
    return Base::_emplace_hint(hint, emplace_type(std::forward<Args>(args)...));
  }
  /// Erases by key; returns the result of `Base::_erase(x)` converted to size_t.
  size_t erase(const key_type& x) {
    return Base::_erase(x);
  }
  /// Erases by position; returns whatever `Base::_erase(it)` returns.
  iterator erase(iterator it) {
    return Base::_erase(it);
  }
  iterator erase(const_iterator it) {
    return Base::_erase(it);
  }
};

template<typename Base>
using SetTraits = SetTraitsBase<Base>;

/// Adds the std::map-style `operator[]` on top of `SetTraitsBase`.
template<typename Base>
class MapTraits : public SetTraitsBase<Base> {
  using SBase = SetTraitsBase<Base>;
 public:
  using typename SBase::key_type;
  using typename SBase::mapped_type;
  using typename SBase::value_type;
  using reference = mapped_type&;
 private:
  /// Shared body of every `operator[]` overload: finds the entry for `x`,
  /// inserting a value-initialised mapped value (using the lower bound as
  /// hint) when it is absent, and returns the mapped value.
  template<typename Key>
  reference _subscript(Key&& x) {
    auto i = SBase::lower_bound(x);
    if (i == SBase::end() || x < i->first) {
      i = SBase::emplace_hint(i, std::forward<Key>(x), mapped_type());
    }
    return i->second;
  }
 public:
  MapTraits() = default;
  template<typename InputIt>
  explicit MapTraits(InputIt begin, InputIt end) : SBase(begin, end) {}
  MapTraits(std::initializer_list<value_type> init) : SBase(init) {}
  /// Returns the mapped value for `x`, inserting a default one if absent.
  template<typename Key>
  reference operator[](Key&& x) {
    return _subscript(std::forward<Key>(x));
  }
  reference operator[](const key_type& x) {
    return _subscript(x);
  }
  reference operator[](key_type&& x) {
    return _subscript(std::move(x));
  }
};

} // namespace traits
#line 2 "include/mtl/bit_manip.hpp"
#include <cstdint>
#include <cassert>
#if __cplusplus >= 202002L
#ifndef MTL_CPP20
#define MTL_CPP20
#endif
#include <bit>
#endif

namespace bm {

/// Count 1s for each 8 bits: byte i of the result is the popcount of byte i of x.
inline constexpr uint64_t popcnt_e8(uint64_t x) {
  x = (x & 0x5555555555555555) + ((x>>1) & 0x5555555555555555);
  x = (x & 0x3333333333333333) + ((x>>2) & 0x3333333333333333);
  x = (x & 0x0F0F0F0F0F0F0F0F) + ((x>>4) & 0x0F0F0F0F0F0F0F0F);
  return x;
}
/// Count 1s
inline constexpr unsigned popcnt(uint64_t x) {
#ifdef MTL_CPP20
  return std::popcount(x);
#else
  // Multiplying by 0x0101... sums the eight per-byte counts into the top byte.
  return (popcnt_e8(x) * 0x0101010101010101) >> 56;
#endif
}
/// Alias to bm::popcnt(x)
constexpr unsigned popcount(uint64_t x) {
  return popcnt(x);
}
/// Count trailing 0s. s.t. *11011000 -> 3 (64 for x == 0)
inline constexpr unsigned ctz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_zero(x);
#else
  // (x & -x) isolates the lowest set bit; minus one yields a mask of the
  // trailing zeros, whose popcount is the answer.
  return popcnt((x & (-x)) - 1);
#endif
}
/// Alias to bm::ctz(x)
constexpr unsigned countr_zero(uint64_t x) {
  return ctz(x);
}
/// Count trailing 1s. s.t. *11011011 -> 2
inline constexpr unsigned cto(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_one(x);
#else
  return ctz(~x);
#endif
}
/// Alias to bm::cto(x)
constexpr unsigned countr_one(uint64_t x) {
  return cto(x);
}
/// Count trailing 0s of an 8-bit value (8 for x == 0)
inline constexpr unsigned ctz8(uint8_t x) {
  return x == 0 ? 8 : popcnt_e8((x & (-x)) - 1);
}
/// [00..0](8bit) -> 0, [**..*](not only 0) -> 1
/// Bit i of the result tells whether byte i of x is non-zero.
inline constexpr uint8_t summary(uint64_t x) {
  constexpr uint64_t hmask = 0x8080808080808080ull;
  constexpr uint64_t lmask = 0x7F7F7F7F7F7F7F7Full;
  auto a = x & hmask;
  auto b = x & lmask;
  // Per byte: 0x80 - low7 keeps the top bit only when low7 == 0, so after
  // the inversion the top bit is set exactly when low7 != 0.
  b = hmask - b;
  b = ~b;
  auto c = (a | b) & hmask;
  // Gather the eight per-byte top bits into the most significant byte.
  c *= 0x0002040810204081ull;
  return uint8_t(c >> 56);
}
/// Extract target area of bits: `len` bits of x starting at bit `start`
inline constexpr uint64_t bextr(uint64_t x, unsigned start, unsigned len) {
  uint64_t mask = len < 64 ? (1ull<<len)-1 : 0xFFFFFFFFFFFFFFFFull;
  return (x >> start) & mask;
}
/// 00101101 -> 00111111 -count_1s-> 6 (bit width of x; 0 for x == 0)
inline constexpr unsigned log2p1(uint8_t x) {
  // Without this guard the borrow chain below leaves byte 7 at 0x7F for
  // x == 0 and the function would report 1 instead of 0.
  if (x == 0)
    return 0;
  if (x & 0x80)
    return 8;
  // Byte k of p becomes x - 2^k; its top bit stays clear exactly when
  // x >= 2^k, so counting the clear top bits gives the bit width.
  uint64_t p = uint64_t(x) * 0x0101010101010101ull;
  p -= 0x8040201008040201ull;
  p = ~p & 0x8080808080808080ull;
  p = (p >> 7) * 0x0101010101010101ull;
  p >>= 56;
  return p;
}
/// 00101100 -mask_mssb-> 00100000 -to_index-> 5
inline constexpr unsigned mssb8(uint8_t x) {
  assert(x != 0);
  return log2p1(x) - 1;
}
/// 00101100 -mask_lssb-> 00000100 -to_index-> 2
inline constexpr unsigned lssb8(uint8_t x) {
  assert(x != 0);
  return popcnt_e8((x & -x) - 1);
}
/// Count leading 0s. 00001011... -> 4 (64 for x == 0)
inline constexpr unsigned clz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_zero(x);
#else
  if (x == 0)
    return 64;
  // i: index of the highest non-zero byte, j: highest set bit inside it.
  auto i = mssb8(summary(x));
  auto j = mssb8(bextr(x, 8 * i, 8));
  return 63 - (8 * i + j);
#endif
}
/// Alias to bm::clz(x)
constexpr unsigned countl_zero(uint64_t x) {
  return clz(x);
}
/// Count leading 1s. 11110100... -> 4
inline constexpr unsigned clo(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_one(x);
#else
  return clz(~x);
#endif
}
/// Alias to bm::clo(x)
constexpr unsigned countl_one(uint64_t x) {
  return clo(x);
}

/// Count leading 0s of an 8-bit value (8 for x == 0)
inline constexpr unsigned clz8(uint8_t x) {
  return x == 0 ? 8 : 7 - mssb8(x);
}
/// Reverse the order of all 64 bits of x
inline constexpr uint64_t bit_reverse(uint64_t x) {
  // Swap halves, then 16-bit groups, bytes, nibbles, bit pairs and bits.
  x = ((x & 0x00000000FFFFFFFF) << 32) | ((x & 0xFFFFFFFF00000000) >> 32);
  x = ((x & 0x0000FFFF0000FFFF) << 16) | ((x & 0xFFFF0000FFFF0000) >> 16);
  x = ((x & 0x00FF00FF00FF00FF) << 8) | ((x & 0xFF00FF00FF00FF00) >> 8);
  x = ((x & 0x0F0F0F0F0F0F0F0F) << 4) | ((x & 0xF0F0F0F0F0F0F0F0) >> 4);
  x = ((x & 0x3333333333333333) << 2) | ((x & 0xCCCCCCCCCCCCCCCC) >> 2);
  x = ((x & 0x5555555555555555) << 1) | ((x & 0xAAAAAAAAAAAAAAAA) >> 1);
  return x;
}

/// Check if x is power of 2. 00100000 -> true, 00100001 -> false
constexpr bool has_single_bit(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::has_single_bit(x);
#else
  return x != 0 && (x & (x - 1)) == 0;
#endif
}

/// Bit width needs to represent x. 00110110 -> 6
constexpr int bit_width(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::bit_width(x);
#else
  return 64 - clz(x);
#endif
}

/// Ceil power of 2. 00110110 -> 01000000
constexpr uint64_t bit_ceil(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_ceil(x);
#else
  if (x == 0) return 1;
  return 1ull << bit_width(x - 1);
#endif
}

/// Floor power of 2. 00110110 -> 00100000
constexpr uint64_t bit_floor(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_floor(x);
#else
  if (x == 0) return 0;
  return 1ull << (bit_width(x) - 1);
#endif
}

} // namespace bm
#line 4 "include/mtl/binary_trie.hpp"
#include <array>
#include <memory>
#line 9 "include/mtl/binary_trie.hpp"
#include <algorithm>
#line 11 "include/mtl/binary_trie.hpp"

/// Binary trie over unsigned keys of `W` bits; the implementation class behind
/// the `traits::SetTraits` / `traits::MapTraits` wrappers.
///
/// Layout, as established by the code below:
///  - Inner nodes sit at depths 0..W-1 and branch on one key bit per level,
///    most significant bit first. Leaves sit at depth W.
///  - Leaves reuse their two child slots as prev/next links of a circular
///    doubly linked list that runs through the sentinel leaf `dummy_`, which
///    also serves as `end()`.
///  - `Node::jump` is a shortcut to a leaf that `_lower_bound` uses when the
///    search stops at a node because the wanted child is missing.
///
/// A moved-from trie holds null `root_` / `dummy_`; it may only be destroyed,
/// cleared or assigned to.
template<typename T, typename M,
    int8_t W = sizeof(T) * 8>
class BinaryTrieBase : public traits::AssociativeArrayDefinition<T, M> {
  static_assert(std::is_unsigned<T>::value, "");
 public:
  using types = traits::AssociativeArrayDefinition<T, M>;
  using key_type = typename types::key_type;
  using value_type = typename types::value_type;
  using init_type = typename types::init_type;
  struct Node;
  using node_ptr = std::shared_ptr<Node>;
  using node_weak_ptr = std::weak_ptr<Node>;
  struct Leaf;
  using leaf_ptr = std::shared_ptr<Leaf>;
  /// Inner trie node. Children are owned; the parent link is weak so that
  /// parent/child pairs do not form an ownership cycle.
  struct Node {
    std::array<node_ptr, 2> c;
    leaf_ptr jump;
    node_weak_ptr parent;
    Node() = default;
    node_ptr& left() { return c[0]; }
    node_ptr& right()  { return c[1]; }
  };
  /// Leaf holding one stored value; `c[0]` / `c[1]` are the previous / next
  /// leaf of the linked list rather than children.
  struct Leaf : Node {
    value_type v;
    Leaf() = default;
    Leaf(const value_type& v) : Node(), v(v) {}
    Leaf(value_type&& v) : Node(), v(std::forward<value_type>(v)) {}
    key_type key() const {
      return types::key_of(v);
    }
    using Node::c;
    leaf_ptr prev() const {
      return std::static_pointer_cast<Leaf>(c[0]);
    }
    leaf_ptr next() const {
      return std::static_pointer_cast<Leaf>(c[1]);
    }
    void set_prev(leaf_ptr l) {
      c[0] = std::static_pointer_cast<Node>(l);
    }
    void set_next(leaf_ptr l) {
      c[1] = std::static_pointer_cast<Node>(l);
    }
  };
 protected:
  node_ptr root_;
  leaf_ptr dummy_;
  size_t size_ = 0;
  /// Puts the object into the empty state: a fresh root whose `jump` is the
  /// sentinel, and a sentinel linked to itself.
  virtual void _init() {
    root_ = create_node_at(0, 0);
    dummy_ = std::make_shared<Leaf>();
    root_->jump = dummy_;
    dummy_->set_next(dummy_);
    dummy_->set_prev(dummy_);
    size_ = 0;
  }
  /// Releases all nodes. The leaf list is a cycle of owning pointers, so it
  /// is unlinked leaf by leaf; dropping `dummy_` alone would not free it.
  void _deinit() {
    root_ = nullptr;
    // A moved-from object has no sentinel and therefore nothing to unlink.
    if (!dummy_) return;
    auto u = dummy_->next();
    dummy_->set_next(nullptr);
    u->set_prev(nullptr);
    while (u != dummy_) {
      auto n = u->next();
      u->set_next(nullptr);
      n->set_prev(nullptr);
      u = n;
    }
    dummy_ = nullptr;
  }
  /// Fills this (uninitialised or torn-down) object with a copy of `rhs`.
  /// A moved-from `rhs` is treated as empty.
  void _copy_from(const BinaryTrieBase& rhs) {
    if (rhs.dummy_)
      _insert_init(rhs.begin(), rhs.end());
    else
      _init();
  }
 public:
  BinaryTrieBase() {
    _init();
  }
  BinaryTrieBase(const BinaryTrieBase& rhs) {
    _copy_from(rhs);
  }
  virtual BinaryTrieBase& operator=(const BinaryTrieBase& rhs) {
    // Self-assignment must not tear down the source before it is read.
    if (this == &rhs) return *this;
    _deinit();
    _copy_from(rhs);
    return *this;
  }
  /// Steals the nodes of `rhs`, leaving it moved-from (see class comment).
  BinaryTrieBase(BinaryTrieBase&& rhs) noexcept
      : root_(std::move(rhs.root_)),
        dummy_(std::move(rhs.dummy_)),
        size_(rhs.size_) {
    rhs.size_ = 0;
  }
  virtual BinaryTrieBase& operator=(BinaryTrieBase&& rhs) noexcept {
    if (this == &rhs) return *this;
    _deinit();
    root_ = std::move(rhs.root_);
    dummy_ = std::move(rhs.dummy_);
    size_ = rhs.size_;
    rhs.size_ = 0;
    return *this;
  }
  virtual ~BinaryTrieBase() {
    _deinit();
  }
 protected:
  /// Initialises the trie from `[begin, end)`.
  /// Unsorted input is inserted element by element. Input sorted by key is
  /// built in one pass that appends each new leaf at the tail of the list and
  /// only creates the nodes below the point where its key diverges from the
  /// previous key. Duplicate keys in sorted input are skipped.
  template<class InputIt>
  void _insert_init(InputIt begin, InputIt end) {
    static_assert(std::is_convertible<typename std::iterator_traits<InputIt>::value_type, value_type>::value, "");
    _init();
    if (begin == end) return;
    if (!std::is_sorted(begin, end, [](const auto& l, const auto& r) {
      return types::key_of(l) < types::key_of(r);
    })) {
      for (auto it = begin; it != end; ++it)
        _insert(*it);
      return;
    }
    // Appends leaf `l` just before the sentinel, i.e. at the list tail.
    auto push_link = [&](leaf_ptr l) {
      auto b = dummy_->prev();
      l->set_prev(b);
      l->set_next(dummy_);
      l->prev()->set_next(l);
      l->next()->set_prev(l);
    };
    // us[i] is the depth-i node on the path of the most recently added key.
    std::array<node_ptr, W> us{};
    // Creates the path for key `x` below depth `k` and hangs leaf `l` on it.
    auto grow = [&](key_type x, int k, leaf_ptr l) {
      for (int i = k; i < W-1; i++) {
        us[i+1] = create_node_at(x, i+1);
        int c = (x >> (W-i-1)) & 1;
        us[i]->c[c] = us[i+1];
        us[i+1]->parent = us[i];
        us[i+1]->jump = l;
      }
      int c = x & 1;
      us[W-1]->c[c] = l;
      l->parent = us[W-1];
    };
    us[0] = root_;
    key_type x = types::key_of(*begin);
    auto l = create_leaf_at(x, *begin);
    push_link(l);
    us[0]->jump = l;
    grow(x, 0, l);
    size_t sz = 1;
    for (auto it = std::next(begin); it != end; ++it) {
      key_type px = x;
      x = types::key_of(*it);
      auto m = x ^ px;
      if (m == 0) continue;
//      [[assume(m != 0)]]
      // k: depth of the node where x and the previous key diverge.
      int k = W-1;
      while (m > 1) {
        m >>= 1;
        --k;
      }
      l = create_leaf_at(x, *it);
      push_link(l);
      for (int i = 0; i < k; i++)
        if (!us[i]->c[1]) us[i]->jump = l;
      us[k]->jump = nullptr;
      grow(x, k, l);
      ++sz;
    }
    size_ = sz;
  }
 public:
  template<typename InputIt>
  explicit BinaryTrieBase(InputIt begin, InputIt end) {
    _insert_init(begin, end);
  }
  size_t size() const {
    return size_;
  }
  bool empty() const { return size() == 0; }
  void clear() {
    _deinit();
    _init();
  }
 protected:
  template<bool> struct iterator_base;
 public:
  using iterator = iterator_base<false>;
  using const_iterator = iterator_base<true>;
  iterator begin() {
    return iterator(dummy_->next());
  }
  iterator end() {
    return iterator(dummy_);
  }
  const_iterator begin() const {
    return const_iterator(dummy_->next());
  }
  const_iterator end() const {
    return const_iterator(dummy_);
  }
  /// Walks from the root to depth W, asking `rule(bit_index, has_left,
  /// has_right)` at every level which child (0 or 1) to descend into.
  /// `rule` must only pick existing children; nothing is checked here.
  template<class Rule>
  const_iterator traverse(Rule rule) const {
    auto u = root_;
    for (int i = 0; i < W; i++) {
      auto l = (bool)u->c[0];
      auto r = (bool)u->c[1];
      auto c = rule(W-1-i, l, r);
      u = u->c[c];
    }
    return const_iterator(std::static_pointer_cast<Leaf>(u));
  }
 protected:
  /// Follows the bits of `key` starting at `depth` from `root` (the trie root
  /// when null) for as long as the wanted child exists.
  /// Returns the depth reached and the node there; depth W means the node is
  /// the leaf holding `key`.
  virtual std::pair<int, node_ptr> _traverse(const key_type& key, 
                                             int depth = 0, 
                                             node_ptr root = nullptr) const {
    int i, c;
    key_type x = key;
    auto u = !root ? root_ : root;
    for (i = depth; i < W; i++) {
      c = (x >> (W-i-1)) & 1;
      if (!u->c[c]) break;
      u = u->c[c];
    }
    return std::make_pair(i, u);
  }
  /// First element whose key is not less than `x`, or `end()`.
  iterator _lower_bound(const key_type& x) const {
    auto reached = _traverse(x);
    int i = reached.first;
    node_ptr u = reached.second;
    if (i == W) return iterator(std::static_pointer_cast<Leaf>(u));
    // Missing left child: `jump` itself is the answer. Missing right child:
    // the answer is the leaf following `jump`.
    auto l = (((x >> (W-i-1)) & 1) == 0) ? u->jump : u->jump->next();
    return iterator(l);
  }
  /// First element whose key is greater than `x`, or `end()`.
  iterator _upper_bound(const key_type& x) const {
    auto it = _lower_bound(x);
    // Never inspect the sentinel's placeholder value.
    if (it != end() && types::key_of(*it) == x)
      ++it;
    return it;
  }
  /// Element with key `x`, or `end()`.
  virtual iterator _find(const key_type& x) const {
    auto reached = _traverse(x);
    int i = reached.first;
    node_ptr u = reached.second;
    if (i == W)
      return iterator(std::static_pointer_cast<Leaf>(u));
    else
      return end();
  }
  // Allocation and erase hooks. They are virtual and receive the key and
  // depth, which this class ignores.
  virtual node_ptr create_node_at(const key_type&, int) {
    return std::make_shared<Node>();
  }
  virtual leaf_ptr create_leaf_at(const key_type&, const init_type& value) {
    return std::make_shared<Leaf>(value);
  }
  virtual leaf_ptr create_leaf_at(const key_type&, init_type&& value) {
    return std::make_shared<Leaf>(std::move(value));
  }
  /// Inserts `value` with key `x`, which must be absent. `forked` is the node
  /// at depth `height` where the search for `x` fell off the trie.
  /// Links the new leaf into the list, creates the missing path below
  /// `forked`, then refreshes `jump` pointers on the ancestors.
  template<typename Value>
  iterator _emplace_impl(key_type x, int height, node_ptr forked, Value&& value) {
    assert(height < W);
    int i = height;
    node_ptr u = forked;
    auto f = u;
    int c = (x >> (W-i-1)) & 1;
    auto fc = c;
    auto fi = i;
    // List predecessor of the new leaf, derived from the fork node's jump.
    auto pred = c == 1 ? u->jump : u->jump->prev();
    u->jump = nullptr;
    auto l = create_leaf_at(x, std::forward<Value>(value));
    l->set_prev(pred);
    l->set_next(pred->next());
    l->prev()->set_next(l);
    l->next()->set_prev(l);
    for (; i < W-1; i++) {
      c = (x >> (W-i-1)) & 1;
      assert(!u->c[c]);
      u->c[c] = create_node_at(x, i+1);
      u->c[c]->parent = u;
      u->c[c]->jump = l;
      u = u->c[c];
    }
    {
      c = (x >> (W-i-1)) & 1;
      u->c[c] = l;
      u->c[c]->parent = u;
    }
    if (f == root_) [[unlikely]] {
      f->jump = l;
    } else [[likely]] {
      auto v = f->parent.lock();
      fi--;
      while (v) {
        c = x >> (W-fi-1) & 1;
        if (c != fc and !v->jump)
          break;
        if (!v->c[fc])
          v->jump = l;
        v = v->parent.lock();
        fi--;
      }
    }
    size_++;
    return iterator(l);
  }
  /// Inserts `value` unless its key is already present.
  /// Returns the element with that key and whether an insertion happened.
  template<typename Value>
  std::pair<iterator, bool> _insert(Value&& value) {
    static_assert(std::is_convertible<Value, value_type>::value, "");
    key_type x = types::key_of(value);
    auto reached = _traverse(x);
    int i = reached.first;
    node_ptr u = reached.second;
    if (i == W)
      return std::make_pair(iterator(std::static_pointer_cast<Leaf>(u)), false);
    return std::make_pair(_emplace_impl(x, i, u, std::forward<Value>(value)), true);
  }
  /// Climbs from leaf `l` to the deepest node whose key prefix is shared by
  /// `l`'s key and `x`; returns that node with its depth (W and `l` itself
  /// when the keys are equal).
  virtual std::pair<int, node_ptr> climb_to_lca(leaf_ptr l, key_type x) {
    key_type m = x ^ types::key_of(l->v);
    if (m == 0)
      return std::make_pair(W, std::static_pointer_cast<Node>(l));
    // bm::clz works on 64 bits; subtract the bits above the W-bit key.
    int h = bm::clz(m) - (64 - W);
    node_ptr f = std::static_pointer_cast<Node>(l);
    for (int i = W; i > h; i--)
      f = f->parent.lock();
    return std::make_pair(h, f);
  }
  /// Inserts `value` unless its key is present, starting the search from the
  /// element at `hint` (the last element when `hint == end()`) rather than
  /// from the root.
  template<class Value>
  iterator _emplace_hint(const_iterator hint, Value&& value) {
    key_type x = types::key_of(value);
    if (empty())
      return _emplace_impl(x, 0, root_, std::forward<Value>(value));
    if (hint == end())
      --hint;
    int h;
    node_ptr f;
    std::tie(h, f) = climb_to_lca(hint.ptr_, x);
    std::tie(h, f) = _traverse(x, h, f);
    if (h == W)
      return iterator(std::static_pointer_cast<Leaf>(f));
    return _emplace_impl(x, h, f, std::forward<Value>(value));
  }

  virtual void erase_node_at(const key_type&, int, node_ptr) {}
  /// Erases the element with key `key`; returns whether one existed.
  virtual bool _erase(const key_type& key) {
    auto it = _find(key);
    if (it != end()) {
      _erase_from_leaf(types::key_of(*it), it.ptr_);
      return true;
    } else {
      return false;
    }
  }
  /// Removes leaf `l`, whose key must equal `key`: unlinks it from the list,
  /// deletes the path nodes that become childless and refreshes `jump`
  /// pointers on the ancestors. Returns the element that followed `l`.
  template<typename Key>
  iterator _erase_from_leaf(Key&& key, leaf_ptr l) {
    static_assert(std::is_convertible<Key, key_type>::value, "");
    key_type x = std::forward<Key>(key);
    assert(x == l->key());
    l->prev()->set_next(l->next());
    l->next()->set_prev(l->prev());
    int i,c;
    auto v = std::static_pointer_cast<Node>(l);
    // Climb while detaching, stopping at the first node that keeps a child.
    for (i = W-1; i >= 0; i--) {
      erase_node_at(x, i+1, v);
      v = v->parent.lock();
      c = (x >> (W-i-1)) & 1;
      v->c[c] = nullptr;
      if (v->c[c^1]) break;
    }
    auto nj = c ? l->prev() : l->next();
    auto fc = c;
    v->jump = nj;
    v = v->parent.lock();
    i--;
    for (; i >= 0; i--) {
      assert(v);
      c = (x >> (W-i-1)) & 1;
      if (c != fc) {
        if (!v->jump) break;
        v->jump = nj;
      }
      v = v->parent.lock();
    }
    size_--;
    return iterator(l->next());
  }
  iterator iterator_remove_const(const const_iterator& it) {
    return iterator(it.ptr_);
  }
  iterator iterator_remove_const(const_iterator&& it) {
    return iterator(std::move(it.ptr_));
  }
  /// Erases the element at `it` (no-op for `end()`); returns its successor.
  iterator _erase(iterator it) {
    if (it == end()) return it;
    return _erase_from_leaf(types::key_of(*it), it.ptr_);
  }
  iterator _erase(const_iterator it) {
    if (it == end()) return iterator_remove_const(it);
    return _erase_from_leaf(types::key_of(*it), it.ptr_);
  }
 protected:
  /// Bidirectional iterator over the leaf list; `Const` selects const access
  /// to the element. It holds an owning pointer to the current leaf.
  /// The two flavours convert into each other freely; const member functions
  /// of the trie rely on this to return `iterator`.
  template<bool Const>
  struct iterator_base {
    using difference_type = ptrdiff_t;
    using value_type = BinaryTrieBase::value_type;
    using pointer = typename std::conditional<Const,
                                              const value_type*,
                                              value_type*>::type;
    using reference = typename std::conditional<Const,
                                                const value_type&,
                                                value_type&>::type;
    using iterator_category = std::bidirectional_iterator_tag;
    leaf_ptr ptr_;
    iterator_base(leaf_ptr p) : ptr_(p) {}
    template<bool C>
    iterator_base(const iterator_base<C>& rhs) : ptr_(rhs.ptr_) {}
    template<bool C>
    iterator_base& operator=(const iterator_base<C>& rhs) {
      ptr_ = rhs.ptr_;
      return *this;
    }
    template<bool C>
    iterator_base(iterator_base<C>&& rhs) : ptr_(std::move(rhs.ptr_)) {}
    template<bool C>
    iterator_base& operator=(iterator_base<C>&& rhs) {
      ptr_ = std::move(rhs.ptr_);
      return *this;
    }
    // Dereferencing does not modify the iterator, so these are const.
    reference operator*() const {
      return ptr_->v;
    }
    pointer operator->() const {
      return &(ptr_->v);
    }
    template<bool C>
    bool operator==(const iterator_base<C>& rhs) const {
      return ptr_ == rhs.ptr_;
    }
    template<bool C>
    bool operator!=(const iterator_base<C>& rhs) const {
      return !operator==(rhs);
    }
    iterator_base& operator++() {
      ptr_ = ptr_->next();
      return *this;
    }
    // Postfix forms advance *this and therefore cannot be const.
    iterator_base operator++(int) {
      iterator_base ret = *this;
      operator++();
      return ret;
    }
    iterator_base& operator--() {
      ptr_ = ptr_->prev();
      return *this;
    }
    iterator_base operator--(int) {
      iterator_base ret = *this;
      operator--();
      return ret;
    }
  };
};

/// Map from unsigned keys `T` to values `V` backed by `BinaryTrieBase`;
/// `W` is the number of key bits and defaults to 8 bits per byte of `T`.
template<typename T, typename V, uint8_t W = sizeof(T)*8>
using BinaryTrie = traits::MapTraits<BinaryTrieBase<T, V, W>>;
/// Set of unsigned keys `T` backed by `BinaryTrieBase` (mapped type `void`).
template<typename T, uint8_t W = sizeof(T)*8>
using BinaryTrieSet = traits::SetTraits<BinaryTrieBase<T, void, W>>;
/// Alternative name for `BinaryTrie`.
template<typename T, typename V, uint8_t W = sizeof(T)*8>
using BinaryTrieMap = BinaryTrie<T, V, W>;
#line 3 "test/standalone/binary_trie_test.cpp"
#include <iostream>
#include <vector>

void test_constructor() {
    BinaryTrie<uint32_t, uint32_t> trie;
    assert(trie.empty());
    assert(trie.size() == 0);
}

void test_insert() {
    BinaryTrie<uint32_t, uint32_t> trie;
    auto result1 = trie.insert(std::make_pair(5, 10));
    auto result2 = trie.insert(std::make_pair(5, 20));
    assert(trie.size() == 1);
    assert(result1.second);
    assert(!result2.second);
    assert(result1.first->second == 10);
}

void test_find() {
    BinaryTrie<uint32_t, uint32_t> trie;
    trie.insert(std::make_pair(5, 10));
    trie.insert(std::make_pair(3, 20));

    auto result1 = trie.find(5);
    auto result2 = trie.find(3);
    auto result3 = trie.find(7);

    assert(result1 != trie.end());
    assert(result1->second == 10);
    assert(result2 != trie.end());
    assert(result2->second == 20);
    assert(result3 == trie.end());
}

void test_erase() {
    BinaryTrie<uint32_t, uint32_t> trie;
    trie.insert(std::make_pair(5, 10));
    trie.insert(std::make_pair(3, 20));

    bool result1 = trie.erase(5);
    bool result2 = trie.erase(7);

    assert(result1);
    assert(!result2);
    assert(trie.size() == 1);
    assert(trie.find(5) == trie.end());
    assert(trie.find(3) != trie.end());
}

void test_range_constructor() {
    std::vector<std::pair<uint32_t, uint32_t>> values = {
        {1, 10},
        {3, 30},
        {5, 50},
    };

    BinaryTrie<uint32_t, uint32_t> trie(values.begin(), values.end());
    assert(trie.size() == 3);
    assert(trie.find(1) != trie.end());
    assert(trie.find(3) != trie.end());
    assert(trie.find(5) != trie.end());
}

int main() {
    test_constructor();
    test_insert();
    test_find();
    test_erase();
    test_range_constructor();

    std::cout << "All tests passed!" << std::endl;
    return 0;
}
Back to top page