This documentation is automatically generated by competitive-verifier/competitive-verifier
#define STANDALONE
#include "include/mtl/binary_trie.hpp"
#include <iostream>
#include <vector>
void test_constructor() {
BinaryTrie<uint32_t, uint32_t> trie;
assert(trie.empty());
assert(trie.size() == 0);
}
/// insert() adds a new key once; inserting the same key again is rejected
/// and the originally stored mapped value is kept.
void test_insert() {
  BinaryTrie<uint32_t, uint32_t> trie;
  auto inserted = trie.insert(std::make_pair(5, 10));
  auto rejected = trie.insert(std::make_pair(5, 20));
  assert(trie.size() == 1);
  assert(inserted.second);
  assert(!rejected.second);
  assert(inserted.first->second == 10);
}
/// find() yields an iterator to the element of a present key and end() for
/// an absent key.
void test_find() {
  BinaryTrie<uint32_t, uint32_t> trie;
  trie.insert(std::make_pair(5, 10));
  trie.insert(std::make_pair(3, 20));
  auto hit_five = trie.find(5);
  auto hit_three = trie.find(3);
  auto miss_seven = trie.find(7);
  // Short-circuit keeps the dereference behind the end() check.
  assert(hit_five != trie.end() && hit_five->second == 10);
  assert(hit_three != trie.end() && hit_three->second == 20);
  assert(miss_seven == trie.end());
}
/// erase(key) reports whether the key was present and removes only that key.
void test_erase() {
  BinaryTrie<uint32_t, uint32_t> trie;
  trie.insert(std::make_pair(5, 10));
  trie.insert(std::make_pair(3, 20));
  const bool removed_present = trie.erase(5);
  const bool removed_absent = trie.erase(7);
  assert(removed_present);
  assert(!removed_absent);
  assert(trie.size() == 1);
  assert(trie.find(5) == trie.end());
  assert(trie.find(3) != trie.end());
}
void test_range_constructor() {
std::vector<std::pair<uint32_t, uint32_t>> values = {
{1, 10},
{3, 30},
{5, 50},
};
BinaryTrie<uint32_t, uint32_t> trie(values.begin(), values.end());
assert(trie.size() == 3);
assert(trie.find(1) != trie.end());
assert(trie.find(3) != trie.end());
assert(trie.find(5) != trie.end());
}
int main() {
  // The tests report failures through assert (when asserts are enabled), so
  // reaching the message below means every check held.
  void (*const tests[])() = {
      test_constructor,
      test_insert,
      test_find,
      test_erase,
      test_range_constructor,
  };
  for (auto run : tests)
    run();
  std::cout << "All tests passed!" << std::endl;
  return 0;
}
#line 1 "test/standalone/binary_trie_test.cpp"
#define STANDALONE
#line 2 "include/mtl/traits/set_traits.hpp"
#include <cstddef>
#include <initializer_list>
#include <type_traits>
#include <iterator>
#if __cplusplus >= 202002L
// Included at global scope on purpose: a standard header must never be
// included inside a namespace, or its declarations end up nested in it.
#include <concepts>
#endif
namespace traits {

/// Member types shared by associative containers whose elements pair a key
/// of type T with a mapped value of type M.  The M = void specialization
/// below describes set-like containers whose elements are the keys.
template<typename T, typename M>
struct AssociativeArrayDefinition {
  using key_type = T;
  using mapped_type = M;
  using value_type = std::pair<T const, M>;
  using raw_key_type = typename std::remove_const<T>::type;
  using raw_mapped_type = typename std::remove_const<M>::type;
  using init_type = std::pair<raw_key_type, raw_mapped_type>;
  using moved_type = std::pair<raw_key_type&&, raw_mapped_type&&>;

  /// Key part of a pair.
  ///
  /// When the pair's first member already has the key type (ignoring const)
  /// it is returned by reference.  Otherwise it is converted and returned by
  /// value: returning `key_type const&` in that case would bind to the
  /// converted temporary and dangle as soon as this function returns.
  template<class K, class V>
  static typename std::conditional<
      std::is_same<typename std::remove_const<K>::type, raw_key_type>::value,
      key_type const&,
      key_type>::type
  key_of(std::pair<K, V> const& kv) {
    return kv.first;
  }
};

/// Set-like containers: every element is its own key.
template<typename T>
struct AssociativeArrayDefinition<T, void> {
  using key_type = T;
  using value_type = T;
  using init_type = T;
  /// An element of the exact value type is returned by reference.
  static key_type const& key_of(value_type const& k) { return k; }
  /// Any other argument is converted to key_type and returned by value; a
  /// reference to the converted temporary would dangle.
  template<class U>
  static key_type key_of(U const& k) { return k; }
};

/// get_const_iterator<T>::type is T::const_iterator when T declares one.
/// Otherwise it is a thin class derived from T::iterator (which therefore
/// must be a class type) that is constructible from a T::iterator.
template<class T, typename = std::void_t<>>
struct get_const_iterator {
  using base = typename T::iterator;
  struct type : base {
    type(const base& r) : base(r) {}
    type(base&& r) : base(std::move(r)) {}
  };
};
template<class T>
struct get_const_iterator<T, std::void_t<typename T::const_iterator>> {
  using type = typename T::const_iterator;
};

#if __cplusplus >= 202002L
/// Minimal interface a Base must provide to be wrapped by SetTraitsBase.
template<class M>
concept IsAssociativeArray = requires (M m) {
  typename M::key_type;
  typename M::value_type;
  typename M::iterator;
  {m.size()} -> std::convertible_to<size_t>;
  {m.empty()} -> std::same_as<bool>;
  {m.clear()};
  {m.begin()} -> std::same_as<typename M::iterator>;
  {m.end()} -> std::same_as<typename M::iterator>;
};
#endif

/// Adds a std::set-like public interface on top of Base.
///
/// Base supplies storage and the protected primitives `_lower_bound`,
/// `_upper_bound`, `_find`, `_insert`, `_emplace_hint` and `_erase`; this
/// class only forwards to them.
template<class Base>
#if __cplusplus >= 202002L
  requires IsAssociativeArray<Base>
#endif
class SetTraitsBase : public Base {
 public:
  using key_type = typename Base::key_type;
  using value_type = typename Base::value_type;
  using init_type = typename Base::init_type;
  using iterator = typename Base::iterator;
  SetTraitsBase() = default;
  /// Builds the container from [begin, end) through Base's range constructor.
  template<typename InputIt>
  explicit SetTraitsBase(InputIt begin, InputIt end) : Base(begin, end) {
    static_assert(std::is_convertible<typename std::iterator_traits<InputIt>::value_type, value_type>::value, "");
  }
  SetTraitsBase(std::initializer_list<value_type> init) : Base(init.begin(), init.end()) {}
  using Base::size;
  bool empty() const { return size() == 0; }
  using Base::clear;
  using const_iterator = typename get_const_iterator<Base>::type;
  iterator begin() {
    return Base::begin();
  }
  iterator end() {
    return Base::end();
  }
  const_iterator begin() const {
    return const_iterator(Base::begin());
  }
  const_iterator end() const {
    return const_iterator(Base::end());
  }
  const_iterator cbegin() const {
    return begin();
  }
  const_iterator cend() const {
    return end();
  }
  using reverse_iterator = std::reverse_iterator<iterator>;
  /// Standard name, as used by generic code written against std containers.
  using const_reverse_iterator = std::reverse_iterator<const_iterator>;
  /// Legacy spelling of const_reverse_iterator, kept for existing users.
  using reverse_const_iterator = const_reverse_iterator;
  reverse_iterator rbegin() {
    return std::make_reverse_iterator(end());
  }
  reverse_iterator rend() {
    return std::make_reverse_iterator(begin());
  }
  reverse_const_iterator rbegin() const {
    return std::make_reverse_iterator(end());
  }
  reverse_const_iterator rend() const {
    return std::make_reverse_iterator(begin());
  }
  reverse_const_iterator crbegin() const {
    return rbegin();
  }
  reverse_const_iterator crend() const {
    return rend();
  }
  /// First element whose key is not less than x.
  template<class Key>
  iterator lower_bound(const Key& x) const {
    return Base::_lower_bound(x);
  }
  iterator lower_bound(const key_type& x) const {
    return Base::_lower_bound(x);
  }
  /// First element whose key is greater than x.
  template<class Key>
  iterator upper_bound(const Key& x) const {
    return Base::_upper_bound(x);
  }
  iterator upper_bound(const key_type& x) const {
    return Base::_upper_bound(x);
  }
  /// Element with key x, or end() when absent.
  template<class Key>
  iterator find(const Key& x) {
    return Base::_find(x);
  }
  iterator find(const key_type& x) {
    return Base::_find(x);
  }
  template<class Key>
  const_iterator find(const Key& x) const {
    return Base::_find(x);
  }
  const_iterator find(const key_type& x) const {
    return Base::_find(x);
  }
  /// 1 when key x is present, otherwise 0.
  template<class Key>
  size_t count(const Key& x) const {
    return find(x) != end();
  }
  size_t count(const key_type& x) const {
    return find(x) != end();
  }
  /// Inserts v unless its key is present; returns the element with that key
  /// and whether an insertion happened.
  std::pair<iterator, bool> insert(const init_type& v) {
    return Base::_insert(v);
  }
  std::pair<iterator, bool> insert(init_type&& v) {
    return Base::_insert(std::move(v));
  }
  // Templated so the init_type overloads win when both would match.
  template<typename=void>
  std::pair<iterator, bool> insert(const value_type& v) {
    return Base::_insert(v);
  }
  template<typename=void>
  std::pair<iterator, bool> insert(value_type&& v) {
    return Base::_insert(std::move(v));
  }
  /// Builds an element from args (as init_type when possible, otherwise as
  /// value_type) and inserts it.
  template<class... Args>
  std::pair<iterator, bool> emplace(Args&&... args) {
    using emplace_type = typename std::conditional<
        std::is_constructible<init_type, Args...>::value,
        init_type,
        value_type
    >::type;
    return Base::_insert(emplace_type(std::forward<Args>(args)...));
  }
  /// Like emplace, passing hint on to Base::_emplace_hint.
  template<class... Args>
  iterator emplace_hint(const_iterator hint, Args&&... args) {
    using emplace_type = typename std::conditional<
        std::is_constructible<init_type, Args...>::value,
        init_type,
        value_type
    >::type;
    return Base::_emplace_hint(hint, emplace_type(std::forward<Args>(args)...));
  }
  /// Removes key x; returns the number of removed elements (0 or 1).
  size_t erase(const key_type& x) {
    return Base::_erase(x);
  }
  iterator erase(iterator it) {
    return Base::_erase(it);
  }
  iterator erase(const_iterator it) {
    return Base::_erase(it);
  }
};

template<typename Base>
using SetTraits = SetTraitsBase<Base>;

/// SetTraitsBase plus std::map-style operator[].
///
/// mapped_type is taken from Base (presumably via its
/// AssociativeArrayDefinition parent).
template<typename Base>
class MapTraits : public SetTraitsBase<Base> {
  using SBase = SetTraitsBase<Base>;
 public:
  using typename SBase::key_type;
  using typename SBase::mapped_type;
  using typename SBase::value_type;
  using reference = mapped_type&;
  MapTraits() = default;
  template<typename InputIt>
  explicit MapTraits(InputIt begin, InputIt end) : SBase(begin, end) {}
  MapTraits(std::initializer_list<value_type> init) : SBase(init) {}
  /// Value mapped to x; a value-initialized mapped_type is inserted first
  /// when x is absent (lower_bound doubles as the insertion hint).
  template<typename Key>
  reference operator[](Key&& x) {
    auto i = SBase::lower_bound(x);
    if (i == SBase::end() || x < i->first) {
      i = SBase::emplace_hint(i, std::forward<Key>(x), mapped_type());
    }
    return i->second;
  }
  reference operator[](const key_type& x) {
    auto i = SBase::lower_bound(x);
    if (i == SBase::end() || x < i->first) {
      i = SBase::emplace_hint(i, x, mapped_type());
    }
    return i->second;
  }
  reference operator[](key_type&& x) {
    auto i = SBase::lower_bound(x);
    if (i == SBase::end() || x < i->first) {
      i = SBase::emplace_hint(i, std::move(x), mapped_type());
    }
    return i->second;
  }
};
} // namespace traits
#line 2 "include/mtl/bit_manip.hpp"
#include <cstdint>
#include <cassert>
#if __cplusplus >= 202002L
#ifndef MTL_CPP20
#define MTL_CPP20
#endif
#include <bit>
#endif
namespace bm {
/// Count 1s for each 8 bits: byte i of the result holds the number of set
/// bits in byte i of x.
inline constexpr uint64_t popcnt_e8(uint64_t x) {
  x = (x & 0x5555555555555555) + ((x>>1) & 0x5555555555555555);
  x = (x & 0x3333333333333333) + ((x>>2) & 0x3333333333333333);
  x = (x & 0x0F0F0F0F0F0F0F0F) + ((x>>4) & 0x0F0F0F0F0F0F0F0F);
  return x;
}
/// Count 1s
inline constexpr unsigned popcnt(uint64_t x) {
#ifdef MTL_CPP20
  return std::popcount(x);
#else
  // Multiplying by 0x0101... sums all per-byte counts into the top byte.
  return (popcnt_e8(x) * 0x0101010101010101) >> 56;
#endif
}
/// Alias to mtl::popcnt(x)
constexpr unsigned popcount(uint64_t x) {
  return popcnt(x);
}
/// Count trailing 0s. s.t. *11011000 -> 3 (64 for x == 0)
inline constexpr unsigned ctz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_zero(x);
#else
  // (x & -x) isolates the lowest set bit; minus 1 leaves exactly the
  // trailing zeros set (all 64 bits when x == 0).
  return popcnt((x & (-x)) - 1);
#endif
}
/// Alias to mtl::ctz(x)
constexpr unsigned countr_zero(uint64_t x) {
  return ctz(x);
}
/// Count trailing 1s. s.t. *11011011 -> 2
inline constexpr unsigned cto(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_one(x);
#else
  return ctz(~x);
#endif
}
/// Alias to mtl::cto(x)
constexpr unsigned countr_one(uint64_t x) {
  return cto(x);
}
/// Count trailing 0s of an 8-bit value (8 for x == 0).
inline constexpr unsigned ctz8(uint8_t x) {
  return x == 0 ? 8 : popcnt_e8((x & (-x)) - 1);
}
/// [00..0](8bit) -> 0, [**..*](not only 0) -> 1: bit i of the result is set
/// iff byte i of x is non-zero.
inline constexpr uint8_t summary(uint64_t x) {
  constexpr uint64_t hmask = 0x8080808080808080ull;
  constexpr uint64_t lmask = 0x7F7F7F7F7F7F7F7Full;
  auto a = x & hmask;
  auto b = x & lmask;
  // Per byte, 0x80 - low7 keeps the high bit only when low7 == 0 (no borrow
  // crosses bytes because low7 <= 0x7F).
  b = hmask - b;
  b = ~b;
  auto c = (a | b) & hmask;
  // Gathers the eight high bits into the top byte.
  c *= 0x0002040810204081ull;
  return uint8_t(c >> 56);
}
/// Extract target area of bits: len bits of x starting at bit start.
inline constexpr uint64_t bextr(uint64_t x, unsigned start, unsigned len) {
  uint64_t mask = len < 64 ? (1ull<<len)-1 : 0xFFFFFFFFFFFFFFFFull;
  return (x >> start) & mask;
}
/// Bit width of an 8-bit value. 00101101 -> 00111111 -count_1s-> 6; 0 -> 0.
inline constexpr unsigned log2p1(uint8_t x) {
  // Must be handled explicitly: for x == 0 the top lane below computes
  // 0 - 0x80 - borrow == 0x7F and would wrongly contribute 1.
  if (x == 0)
    return 0;
  if (x & 0x80)
    return 8;
  // Byte i of p becomes x - 2^i (minus a borrow from byte i-1); for
  // 1 <= x < 0x80 its high bit ends up clear exactly when x >= 2^i, so
  // counting clear high bits counts the powers of two not exceeding x.
  uint64_t p = uint64_t(x) * 0x0101010101010101ull;
  p -= 0x8040201008040201ull;
  p = ~p & 0x8080808080808080ull;
  p = (p >> 7) * 0x0101010101010101ull;
  p >>= 56;
  return p;
}
/// 00101100 -mask_mssb-> 00100000 -to_index-> 5. Requires x != 0.
inline constexpr unsigned mssb8(uint8_t x) {
  assert(x != 0);
  return log2p1(x) - 1;
}
/// 00101100 -mask_lssb-> 00000100 -to_index-> 2. Requires x != 0.
inline constexpr unsigned lssb8(uint8_t x) {
  assert(x != 0);
  return popcnt_e8((x & -x) - 1);
}
/// Count leading 0s. 00001011... -> 4 (64 for x == 0)
inline constexpr unsigned clz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_zero(x);
#else
  if (x == 0)
    return 64;
  // i: index of the highest non-zero byte; j: highest set bit inside it.
  auto i = mssb8(summary(x));
  auto j = mssb8(bextr(x, 8 * i, 8));
  return 63 - (8 * i + j);
#endif
}
/// Alias to mtl::clz(x)
constexpr unsigned countl_zero(uint64_t x) {
  return clz(x);
}
/// Count leading 1s. 11110100... -> 4
inline constexpr unsigned clo(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_one(x);
#else
  return clz(~x);
#endif
}
/// Alias to mtl::clo(x)
constexpr unsigned countl_one(uint64_t x) {
  return clo(x);
}
/// Count leading 0s of an 8-bit value (8 for x == 0).
inline constexpr unsigned clz8(uint8_t x) {
  return x == 0 ? 8 : 7 - mssb8(x);
}
/// Reverse the order of all 64 bits (bit 0 <-> bit 63, ...).
inline constexpr uint64_t bit_reverse(uint64_t x) {
  x = ((x & 0x00000000FFFFFFFF) << 32) | ((x & 0xFFFFFFFF00000000) >> 32);
  x = ((x & 0x0000FFFF0000FFFF) << 16) | ((x & 0xFFFF0000FFFF0000) >> 16);
  x = ((x & 0x00FF00FF00FF00FF) << 8) | ((x & 0xFF00FF00FF00FF00) >> 8);
  x = ((x & 0x0F0F0F0F0F0F0F0F) << 4) | ((x & 0xF0F0F0F0F0F0F0F0) >> 4);
  x = ((x & 0x3333333333333333) << 2) | ((x & 0xCCCCCCCCCCCCCCCC) >> 2);
  x = ((x & 0x5555555555555555) << 1) | ((x & 0xAAAAAAAAAAAAAAAA) >> 1);
  return x;
}
/// Check if x is power of 2. 00100000 -> true, 00100001 -> false
constexpr bool has_single_bit(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::has_single_bit(x);
#else
  return x != 0 && (x & (x - 1)) == 0;
#endif
}
/// Bit width needs to represent x. 00110110 -> 6 (0 for x == 0)
constexpr int bit_width(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::bit_width(x);
#else
  return 64 - clz(x);
#endif
}
/// Ceil power of 2. 00110110 -> 01000000 (1 for x == 0)
constexpr uint64_t bit_ceil(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_ceil(x);
#else
  if (x == 0) return 1;
  return 1ull << bit_width(x - 1);
#endif
}
/// Floor power of 2. 00110110 -> 00100000 (0 for x == 0)
constexpr uint64_t bit_floor(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_floor(x);
#else
  if (x == 0) return 0;
  return 1ull << (bit_width(x) - 1);
#endif
}
} // namespace bm
#line 4 "include/mtl/binary_trie.hpp"
#include <array>
#include <memory>
#line 9 "include/mtl/binary_trie.hpp"
#include <algorithm>
#line 11 "include/mtl/binary_trie.hpp"
/// Ordered associative container over unsigned keys, stored as a binary trie
/// of depth W.  The child taken at depth i is bit (W-1-i) of the key, so a
/// root-to-leaf path spells the key from bit W-1 down to bit 0.  Leaves are
/// also chained in key order through a circular doubly-linked list that is
/// closed by a sentinel leaf (dummy_); that list drives iteration, and the
/// sentinel is end().
///
/// M is the mapped type, or void for a set.  Several hooks (create_node_at,
/// create_leaf_at, erase_node_at, _traverse, climb_to_lca, ...) are virtual,
/// presumably so derived tries can customise them.
template<typename T, typename M,
         int8_t W = sizeof(T) * 8>
class BinaryTrieBase : public traits::AssociativeArrayDefinition<T, M> {
  static_assert(std::is_unsigned<T>::value, "");
 public:
  using types = traits::AssociativeArrayDefinition<T, M>;
  using key_type = typename types::key_type;
  using value_type = typename types::value_type;
  using init_type = typename types::init_type;
  struct Node;
  using node_ptr = std::shared_ptr<Node>;
  using node_weak_ptr = std::weak_ptr<Node>;
  struct Leaf;
  using leaf_ptr = std::shared_ptr<Leaf>;

  /// Internal trie node.
  struct Node {
    /// Owned children: c[0] continues keys whose next bit is 0, c[1] keys
    /// whose next bit is 1.
    std::array<node_ptr, 2> c;
    /// For a node with a single child: the leaf bordering the missing side,
    /// i.e. the largest leaf below when c[1] is missing and the smallest leaf
    /// below when c[0] is missing.  Cleared when a second child is added (the
    /// root may keep a stale value, which is not consulted while it has both
    /// children).  The root of an empty trie points at the sentinel.
    leaf_ptr jump;
    /// Non-owning back link, so parent/child links form no ownership cycle.
    node_weak_ptr parent;
    Node() = default;
    node_ptr& left() { return c[0]; }
    node_ptr& right() { return c[1]; }
  };

  /// Leaf holding one element.  Its inherited c[0]/c[1] are reused as the
  /// prev/next links of the ordered leaf list.
  struct Leaf : Node {
    value_type v;
    Leaf() = default;
    Leaf(const value_type& v) : Node(), v(v) {}
    Leaf(value_type&& v) : Node(), v(std::forward<value_type>(v)) {}
    key_type key() const {
      return types::key_of(v);
    }
    using Node::c;
    leaf_ptr prev() const {
      return std::static_pointer_cast<Leaf>(c[0]);
    }
    leaf_ptr next() const {
      return std::static_pointer_cast<Leaf>(c[1]);
    }
    void set_prev(leaf_ptr l) {
      c[0] = std::static_pointer_cast<Node>(l);
    }
    void set_next(leaf_ptr l) {
      c[1] = std::static_pointer_cast<Node>(l);
    }
  };

 protected:
  node_ptr root_;
  leaf_ptr dummy_;  // sentinel: end() and the anchor of the leaf list
  size_t size_ = 0;

  /// Allocates an empty trie: a root whose jump is the sentinel, and a
  /// sentinel linked to itself.
  virtual void _init() {
    root_ = create_node_at(0, 0);
    dummy_ = std::make_shared<Leaf>();
    root_->jump = dummy_;
    dummy_->set_next(dummy_);
    dummy_->set_prev(dummy_);
    size_ = 0;
  }

  /// Releases all nodes.  The leaf list is a ring of shared_ptrs, so its
  /// links are cut one by one to break the ownership cycle.  Does nothing
  /// more than dropping the root when the storage was already released
  /// (moved-from object, dummy_ == nullptr).
  void _deinit() {
    root_ = nullptr;
    if (!dummy_)
      return;
    auto u = dummy_->next();
    dummy_->set_next(nullptr);
    u->set_prev(nullptr);
    while (u != dummy_) {
      auto n = u->next();
      u->set_next(nullptr);
      n->set_prev(nullptr);
      u = n;
    }
    dummy_ = nullptr;
  }

 public:
  BinaryTrieBase() {
    _init();
  }
  BinaryTrieBase(const BinaryTrieBase& rhs) {
    _insert_init(rhs.begin(), rhs.end());
  }
  virtual BinaryTrieBase& operator=(const BinaryTrieBase& rhs) {
    // Self-assignment would otherwise iterate storage _deinit just freed.
    if (this == &rhs)
      return *this;
    _deinit();
    _insert_init(rhs.begin(), rhs.end());
    return *this;
  }
  /// Takes over rhs's nodes.  rhs is left without storage: it may only be
  /// destroyed, assigned to, or clear()ed before it is used again.
  BinaryTrieBase(BinaryTrieBase&& rhs) noexcept
      : root_(std::move(rhs.root_)),
        dummy_(std::move(rhs.dummy_)),
        size_(rhs.size_) {
    rhs.size_ = 0;
  }
  /// Same post-condition on rhs as the move constructor.
  virtual BinaryTrieBase& operator=(BinaryTrieBase&& rhs) noexcept {
    if (this == &rhs)
      return *this;
    _deinit();
    root_ = std::move(rhs.root_);
    dummy_ = std::move(rhs.dummy_);
    size_ = rhs.size_;
    rhs.size_ = 0;
    return *this;
  }
  virtual ~BinaryTrieBase() {
    _deinit();
  }

 protected:
  /// Re-initializes the trie with the elements of [begin, end).  Input that
  /// is sorted by key is built in one pass (for equal keys the first one is
  /// kept); otherwise elements are inserted one at a time.
  template<class InputIt>
  void _insert_init(InputIt begin, InputIt end) {
    static_assert(std::is_convertible<typename std::iterator_traits<InputIt>::value_type, value_type>::value, "");
    _init();
    if (begin == end) return;
    if (!std::is_sorted(begin, end, [](auto& l, auto& r) {
          return types::key_of(l) < types::key_of(r);
        })) {
      for (auto it = begin; it != end; ++it)
        _insert(*it);
      return;
    }
    // Appends l at the tail of the leaf list (just before the sentinel).
    auto push_link = [&](leaf_ptr l) {
      auto b = dummy_->prev();
      l->set_prev(b);
      l->set_next(dummy_);
      l->prev()->set_next(l);
      l->next()->set_prev(l);
    };
    // us[i]: node at depth i on the path of the most recently added key.
    std::array<node_ptr, W> us{};
    // Creates the missing nodes below depth k on x's path and hangs l at
    // the bottom; each new node has l as its only leaf, hence jump = l.
    auto grow = [&](key_type x, int k, leaf_ptr l) {
      for (int i = k; i < W-1; i++) {
        us[i+1] = create_node_at(x, i+1);
        int c = (x >> (W-i-1)) & 1;
        us[i]->c[c] = us[i+1];
        us[i+1]->parent = us[i];
        us[i+1]->jump = l;
      }
      int c = x & 1;
      us[W-1]->c[c] = l;
      l->parent = us[W-1];
    };
    us[0] = root_;
    key_type x = types::key_of(*begin);
    auto l = create_leaf_at(x, *begin);
    push_link(l);
    us[0]->jump = l;
    grow(x, 0, l);
    size_t sz = 1;
    for (auto it = std::next(begin); it != end; ++it) {
      key_type px = x;
      x = types::key_of(*it);
      auto m = x ^ px;
      if (m == 0) continue;
      // [[assume(m != 0)]]
      // k: depth of the highest bit where x differs from the previous key.
      int k = W-1;
      while (m > 1) {
        m >>= 1;
        --k;
      }
      l = create_leaf_at(x, *it);
      push_link(l);
      // Keys arrive in increasing order, so l is the new maximum below every
      // shared ancestor that still lacks a right child.
      for (int i = 0; i < k; i++)
        if (!us[i]->c[1]) us[i]->jump = l;
      us[k]->jump = nullptr;
      grow(x, k, l);
      ++sz;
    }
    size_ = sz;
  }

 public:
  template<typename InputIt>
  explicit BinaryTrieBase(InputIt begin, InputIt end) {
    _insert_init(begin, end);
  }
  size_t size() const {
    return size_;
  }
  bool empty() const { return size() == 0; }
  void clear() {
    _deinit();
    _init();
  }

 protected:
  template<bool> struct iterator_base;

 public:
  using iterator = iterator_base<false>;
  using const_iterator = iterator_base<true>;
  iterator begin() {
    return iterator(dummy_->next());
  }
  iterator end() {
    return iterator(dummy_);
  }
  const_iterator begin() const {
    return const_iterator(dummy_->next());
  }
  const_iterator end() const {
    return const_iterator(dummy_);
  }
  /// Walks from the root to a leaf, calling rule(bit, has_left, has_right)
  /// at each level to pick the child (bit runs from W-1 down to 0).  The
  /// trie must be non-empty and rule must pick an existing child.
  template<class Rule>
  const_iterator traverse(Rule rule) const {
    assert(!empty());
    auto u = root_;
    for (int i = 0; i < W; i++) {
      auto l = (bool)u->c[0];
      auto r = (bool)u->c[1];
      auto c = rule(W-1-i, l, r);
      assert(u->c[c]);
      u = u->c[c];
    }
    return const_iterator(std::static_pointer_cast<Leaf>(u));
  }

 protected:
  /// Follows key's bits starting at depth `depth` from `root` (the trie root
  /// when null) as far as nodes exist.  Returns the depth reached and the
  /// node there; depth W means the node is the leaf holding key.
  virtual std::pair<int, node_ptr> _traverse(const key_type& key,
                                             int depth = 0,
                                             node_ptr root = nullptr) const {
    int i, c;
    key_type x = key;
    auto u = !root ? root_ : root;
    for (i = depth; i < W; i++) {
      c = (x >> (W-i-1)) & 1;
      if (!u->c[c]) break;
      u = u->c[c];
    }
    return std::make_pair(i, u);
  }
  /// First element whose key is >= x.  When the path stops early, the stop
  /// node's jump is either that element (missing left side) or its
  /// predecessor (missing right side).
  iterator _lower_bound(const key_type& x) const {
    auto reached = _traverse(x);
    int i = reached.first;
    node_ptr u = reached.second;
    if (i == W) return iterator(std::static_pointer_cast<Leaf>(u));
    auto l = (((x >> (W-i-1)) & 1) == 0) ? u->jump : u->jump->next();
    return iterator(l);
  }
  /// First element whose key is > x.
  iterator _upper_bound(const key_type& x) const {
    auto it = _lower_bound(x);
    // Never read the sentinel's value: lower_bound may return end().
    if (it != end() && types::key_of(*it) == x)
      ++it;
    return it;
  }
  virtual iterator _find(const key_type& x) const {
    auto reached = _traverse(x);
    int i = reached.first;
    node_ptr u = reached.second;
    if (i == W)
      return iterator(std::static_pointer_cast<Leaf>(u));
    else
      return end();
  }
  virtual node_ptr create_node_at(const key_type&, int) {
    return std::make_shared<Node>();
  }
  virtual leaf_ptr create_leaf_at(const key_type&, const init_type& value) {
    return std::make_shared<Leaf>(value);
  }
  virtual leaf_ptr create_leaf_at(const key_type&, init_type&& value) {
    return std::make_shared<Leaf>(std::move(value));
  }
  /// Inserts a leaf for key x below `forked`, the deepest existing node on
  /// x's path (at depth `height`, lacking the child x needs).  Links the leaf
  /// after its predecessor (found through forked->jump), creates the missing
  /// chain of nodes, then refreshes ancestors' jump pointers that must now
  /// refer to the new leaf.
  template<typename Value>
  iterator _emplace_impl(key_type x, int height, node_ptr forked, Value&& value) {
    assert(height < W);
    int i = height;
    node_ptr u = forked;
    auto f = u;
    int c = (x >> (W-i-1)) & 1;
    auto fc = c;
    auto fi = i;
    auto pred = c == 1 ? u->jump : u->jump->prev();
    u->jump = nullptr;
    auto l = create_leaf_at(x, std::forward<Value>(value));
    l->set_prev(pred);
    l->set_next(pred->next());
    l->prev()->set_next(l);
    l->next()->set_prev(l);
    for (; i < W-1; i++) {
      c = (x >> (W-i-1)) & 1;
      assert(!u->c[c]);
      u->c[c] = create_node_at(x, i+1);
      u->c[c]->parent = u;
      u->c[c]->jump = l;
      u = u->c[c];
    }
    {
      c = (x >> (W-i-1)) & 1;
      u->c[c] = l;
      u->c[c]->parent = u;
    }
    if (f == root_) [[unlikely]] {
      f->jump = l;
    } else [[likely]] {
      auto v = f->parent.lock();
      fi--;
      while (v) {
        c = x >> (W-fi-1) & 1;
        if (c != fc and !v->jump)
          break;
        if (!v->c[fc])
          v->jump = l;
        v = v->parent.lock();
        fi--;
      }
    }
    size_++;
    return iterator(l);
  }
  /// Inserts value unless its key is present; returns the element with that
  /// key and whether an insertion happened.
  template<typename Value>
  std::pair<iterator, bool> _insert(Value&& value) {
    static_assert(std::is_convertible<Value, value_type>::value, "");
    key_type x = types::key_of(value);
    auto reached = _traverse(x);
    int i = reached.first;
    node_ptr u = reached.second;
    if (i == W)
      return std::make_pair(iterator(std::static_pointer_cast<Leaf>(u)), false);
    return std::make_pair(_emplace_impl(x, i, u, std::forward<Value>(value)), true);
  }
  /// Climbs from leaf l to the deepest ancestor shared with key x's path;
  /// returns its depth and node (depth W and l itself when the keys match).
  virtual std::pair<int, node_ptr> climb_to_lca(leaf_ptr l, key_type x) {
    key_type m = x ^ types::key_of(l->v);
    if (m == 0)
      return std::make_pair(W, std::static_pointer_cast<Node>(l));
    int h = bm::clz(m) - (64 - W);
    node_ptr f = std::static_pointer_cast<Node>(l);
    for (int i = W; i > h; i--)
      f = f->parent.lock();
    return std::make_pair(h, f);
  }
  /// Insertion that starts the search from hint's leaf instead of the root.
  template<class Value>
  iterator _emplace_hint(const_iterator hint, Value&& value) {
    key_type x = types::key_of(value);
    if (empty())
      return _emplace_impl(x, 0, root_, std::forward<Value>(value));
    if (hint == end())
      --hint;
    int h;
    node_ptr f;
    std::tie(h, f) = climb_to_lca(hint.ptr_, x);
    std::tie(h, f) = _traverse(x, h, f);
    if (h == W)
      return iterator(std::static_pointer_cast<Leaf>(f));
    return _emplace_impl(x, h, f, std::forward<Value>(value));
  }
  virtual void erase_node_at(const key_type&, int, node_ptr) {}
  virtual bool _erase(const key_type& key) {
    auto it = _find(key);
    if (it != end()) {
      _erase_from_leaf(types::key_of(*it), it.ptr_);
      return true;
    } else {
      return false;
    }
  }
  /// Unlinks leaf l (holding key) from the list, prunes the nodes that only
  /// led to it, and repairs jump pointers of the remaining ancestors.
  /// Returns the element following l.
  template<typename Key>
  iterator _erase_from_leaf(Key&& key, leaf_ptr l) {
    static_assert(std::is_convertible<Key, key_type>::value, "");
    key_type x = std::forward<Key>(key);
    assert(x == l->key());
    l->prev()->set_next(l->next());
    l->next()->set_prev(l->prev());
    int i,c;
    auto v = std::static_pointer_cast<Node>(l);
    for (i = W-1; i >= 0; i--) {
      erase_node_at(x, i+1, v);
      v = v->parent.lock();
      c = (x >> (W-i-1)) & 1;
      v->c[c] = nullptr;
      if (v->c[c^1]) break;
    }
    auto nj = c ? l->prev() : l->next();
    auto fc = c;
    v->jump = nj;
    v = v->parent.lock();
    i--;
    for (; i >= 0; i--) {
      assert(v);
      c = (x >> (W-i-1)) & 1;
      if (c != fc) {
        if (!v->jump) break;
        v->jump = nj;
      }
      v = v->parent.lock();
    }
    size_--;
    return iterator(l->next());
  }
  iterator iterator_remove_const(const const_iterator& it) {
    return iterator(it.ptr_);
  }
  iterator iterator_remove_const(const_iterator&& it) {
    return iterator(std::move(it.ptr_));
  }
  iterator _erase(iterator it) {
    if (it == end()) return it;
    return _erase_from_leaf(types::key_of(*it), it.ptr_);
  }
  iterator _erase(const_iterator it) {
    if (it == end()) return iterator_remove_const(it);
    return _erase_from_leaf(types::key_of(*it), it.ptr_);
  }

 protected:
  /// Bidirectional iterator over the ordered leaf list; Const selects
  /// read-only access.  end() is the sentinel leaf.
  template<bool Const>
  struct iterator_base {
    using difference_type = ptrdiff_t;
    using value_type = BinaryTrieBase::value_type;
    using pointer = typename std::conditional<Const,
        const value_type*,
        value_type*>::type;
    using reference = typename std::conditional<Const,
        const value_type&,
        value_type&>::type;
    using iterator_category = std::bidirectional_iterator_tag;
    leaf_ptr ptr_;
    iterator_base(leaf_ptr p) : ptr_(p) {}
    template<bool C>
    iterator_base(const iterator_base<C>& rhs) : ptr_(rhs.ptr_) {}
    template<bool C>
    iterator_base& operator=(const iterator_base<C>& rhs) {
      ptr_ = rhs.ptr_;
      return *this;
    }
    template<bool C>
    iterator_base(iterator_base<C>&& rhs) : ptr_(std::move(rhs.ptr_)) {}
    template<bool C>
    iterator_base& operator=(iterator_base<C>&& rhs) {
      ptr_ = std::move(rhs.ptr_);
      return *this;
    }
    // Dereferencing does not modify the iterator, so these are const.
    reference operator*() const {
      return ptr_->v;
    }
    pointer operator->() const {
      return &(ptr_->v);
    }
    template<bool C>
    bool operator==(const iterator_base<C>& rhs) const {
      return ptr_ == rhs.ptr_;
    }
    template<bool C>
    bool operator!=(const iterator_base<C>& rhs) const {
      return !operator==(rhs);
    }
    iterator_base& operator++() {
      ptr_ = ptr_->next();
      return *this;
    }
    // Postfix forms mutate *this, so they cannot be const.
    iterator_base operator++(int) {
      iterator_base ret = *this;
      operator++();
      return ret;
    }
    iterator_base& operator--() {
      ptr_ = ptr_->prev();
      return *this;
    }
    iterator_base operator--(int) {
      iterator_base ret = *this;
      operator--();
      return ret;
    }
  };
};
/// Map from unsigned keys to V, stored in a binary trie of depth W (by
/// default the bit width of T).
template<class T, class V, uint8_t W = sizeof(T) * 8>
using BinaryTrie = traits::MapTraits<BinaryTrieBase<T, V, W>>;
/// Set of unsigned keys, stored in a binary trie of depth W.
template<class T, uint8_t W = sizeof(T) * 8>
using BinaryTrieSet = traits::SetTraits<BinaryTrieBase<T, void, W>>;
/// Alternative name for BinaryTrie (denotes the very same type).
template<class T, class V, uint8_t W = sizeof(T) * 8>
using BinaryTrieMap = traits::MapTraits<BinaryTrieBase<T, V, W>>;
#line 3 "test/standalone/binary_trie_test.cpp"
#include <iostream>
#include <vector>
void test_constructor() {
BinaryTrie<uint32_t, uint32_t> trie;
assert(trie.empty());
assert(trie.size() == 0);
}
/// insert() adds a new key once; inserting the same key again is rejected
/// and the originally stored mapped value is kept.
void test_insert() {
  BinaryTrie<uint32_t, uint32_t> trie;
  auto inserted = trie.insert(std::make_pair(5, 10));
  auto rejected = trie.insert(std::make_pair(5, 20));
  assert(trie.size() == 1);
  assert(inserted.second);
  assert(!rejected.second);
  assert(inserted.first->second == 10);
}
/// find() yields an iterator to the element of a present key and end() for
/// an absent key.
void test_find() {
  BinaryTrie<uint32_t, uint32_t> trie;
  trie.insert(std::make_pair(5, 10));
  trie.insert(std::make_pair(3, 20));
  auto hit_five = trie.find(5);
  auto hit_three = trie.find(3);
  auto miss_seven = trie.find(7);
  // Short-circuit keeps the dereference behind the end() check.
  assert(hit_five != trie.end() && hit_five->second == 10);
  assert(hit_three != trie.end() && hit_three->second == 20);
  assert(miss_seven == trie.end());
}
/// erase(key) reports whether the key was present and removes only that key.
void test_erase() {
  BinaryTrie<uint32_t, uint32_t> trie;
  trie.insert(std::make_pair(5, 10));
  trie.insert(std::make_pair(3, 20));
  const bool removed_present = trie.erase(5);
  const bool removed_absent = trie.erase(7);
  assert(removed_present);
  assert(!removed_absent);
  assert(trie.size() == 1);
  assert(trie.find(5) == trie.end());
  assert(trie.find(3) != trie.end());
}
void test_range_constructor() {
std::vector<std::pair<uint32_t, uint32_t>> values = {
{1, 10},
{3, 30},
{5, 50},
};
BinaryTrie<uint32_t, uint32_t> trie(values.begin(), values.end());
assert(trie.size() == 3);
assert(trie.find(1) != trie.end());
assert(trie.find(3) != trie.end());
assert(trie.find(5) != trie.end());
}
int main() {
  // The tests report failures through assert (when asserts are enabled), so
  // reaching the message below means every check held.
  void (*const tests[])() = {
      test_constructor,
      test_insert,
      test_find,
      test_erase,
      test_range_constructor,
  };
  for (auto run : tests)
    run();
  std::cout << "All tests passed!" << std::endl;
  return 0;
}