matsutaku-library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub MatsuTaku/matsutaku-library

✔ test/string/dynamic_ac_machine_test.cpp

Code

#define STANDALONE
#include "include/mtl/string/dynamic_ac_machine.hpp"

#include <iostream>
#include <cassert>
#include <vector>
#include <cstring>

// Compares two per-node output tables and returns true when they DIFFER.
// A row missing on one side is treated as an empty row, so tables that only
// differ by trailing empty rows compare equal (returns false).
bool compare_output(const std::vector<std::vector<std::string>>& a,
                    const std::vector<std::vector<std::string>>& b) {
  const std::vector<std::string> empty;
  const size_t rows = std::max(a.size(), b.size());
  for (size_t i = 0; i < rows; i++) {
    const auto& l = (i < a.size()) ? a[i] : empty;
    const auto& r = (i < b.size()) ? b[i] : empty;
    if (l != r) {
      return true;
    }
  }
  return false;
}

int main() {
  std::vector<std::string> keys{"ab", "bc", "bab", "d", "abcde"};
  std::vector<std::vector<std::string>> outputs{
      {},
      {},
      {"ab"},
      {},
      {"bc"},
      {},
      {"bab", "ab"},
      {"d"},
      {"bc"},
      {"d"},
      {"abcde"},
  };
  std::vector<std::vector<std::string>> outputs2{
      {},
      {},
      {"ab"},
      {},
      {"bc"},
      {},
      {"bab", "ab"},
      {"d"},
      {"bc"},
      {"d"},
      {"abcde", "bcde"},
      {"d"},
      {"bcde"}
  };

  DynamicAcMachine<int> ac;
  int id = 0;
  for (auto& key : keys) ac.insert(key, id++);
  auto _ret = ac.all_output();
  std::vector<std::vector<std::string>> ret;
  for (auto& l : _ret) {
    ret.emplace_back();
    for (auto& [k,v] : l) ret.back().push_back(k);
  }
  auto v = compare_output(outputs, ret);
  if (compare_output(outputs, ret)) {
    assert(false);
    exit(EXIT_FAILURE);
  }

  ac.insert("bcde", id++);
  _ret = ac.all_output();
  ret = {};
  for (auto& l : _ret) {
    ret.emplace_back();
    for (auto& [k,v] : l) ret.back().push_back(k);
  }
  if (compare_output(outputs2, ret)) {
    assert(false);
    exit(EXIT_FAILURE);
  }

  ac.erase("bcde");
  _ret = ac.all_output();
  ret = {};
  for (auto& l : _ret) {
    ret.emplace_back();
    for (auto& [k,v] : l) ret.back().push_back(k);
  }
//  for (auto o : ret) {
//    std::cout<<"{ ";
//    for (auto k : o)
//      std::cout<<k<<' ';
//    std::cout<<"}"<<std::endl;
//  }
  if (compare_output(outputs, ret)) {
    assert(false);
    exit(EXIT_FAILURE);
  }

  std::cout << "OK" << std::endl;
}
#line 1 "test/string/dynamic_ac_machine_test.cpp"
#define STANDALONE
#line 2 "include/mtl/string/trie.hpp"
#include <utility>
#include <vector>
#include <map>
#include <memory>

/*
 * Map from keys (sequences of label_type) to values, stored as a trie.
 * nodes[0] is the root.
 * node (TrieNode):
 * - trans: label -> index of the child node in `nodes`
 * - ptr:   the stored (key, value), or nullptr when no key ends at this node
 */
template<typename T, typename KeyType, typename LabelType>
class Trie {
 public:
  using key_type = KeyType;
  using mapped_type = T;
  using value_type = std::pair<key_type, mapped_type>;
  using pointer = std::shared_ptr<value_type>;

  using label_type = LabelType;

  struct TrieNode {
    std::map<label_type, int> trans;
    pointer ptr = nullptr;
  };
  std::vector<TrieNode> nodes;
  using node_iterator = typename decltype(nodes)::iterator;
  using const_node_iterator = typename decltype(nodes)::const_iterator;

  Trie() : nodes(1) {}

  // Returns the node storing `key`, or nodes.end() if `key` is not stored.
  node_iterator find(const key_type& key) {
    const int u = _find_index(key);
    return u < 0 ? nodes.end() : nodes.begin() + u;
  }
  const_node_iterator find(const key_type& key) const {
    const int u = _find_index(key);
    return u < 0 ? nodes.cend() : nodes.cbegin() + u;
  }

  // Creates the path for `key` and stores (key, value) at its last node,
  // unless a value is already stored there (the existing value is kept).
  // Returns the last node of the path.
  node_iterator insert(const key_type& key, const mapped_type& value) {
    int u = 0;
    for (auto c : key) {
      // add_edge may reallocate `nodes`: take begin() only after it returns.
      auto child = add_edge(u, c);
      u = static_cast<int>(child - nodes.begin());
    }
    auto nit = nodes.begin() + u;
    if (!nit->ptr) {
      nit->ptr = std::make_shared<value_type>(key, value);
    }
    return nit;
  }

  // Returns the child of `node` labelled `c`, creating it when absent.
  // Creating a child may reallocate `nodes` and invalidate older iterators.
  node_iterator add_edge(int node, label_type c) {
    auto it = nodes[node].trans.find(c);
    if (it != nodes[node].trans.end()) {
      return nodes.begin() + it->second;
    }
    const int next = static_cast<int>(nodes.size());
    nodes[node].trans.emplace(c, next);
    nodes.emplace_back();
    return nodes.begin() + next;
  }

 private:
  // Index of the node storing `key`, or -1 when `key` is not stored.
  int _find_index(const key_type& key) const {
    int u = 0;
    for (auto c : key) {
      auto next = nodes[u].trans.find(c);
      if (next == nodes[u].trans.end())
        return -1;
      u = next->second;
    }
    return nodes[u].ptr ? u : -1;
  }
};
#line 3 "include/mtl/string/ac_machine.hpp"
#include <iterator>
#include <cassert>
#include <cstring>
#line 7 "include/mtl/string/ac_machine.hpp"
#include <queue>
#include <unordered_map>
#line 10 "include/mtl/string/ac_machine.hpp"
#include <algorithm>
#include <iostream>
#include <typeinfo>
#include <limits>

// Primary template: picked when NodeT lacks one of the four required member
// types. It is intentionally empty, so code that reads those member types
// through the traits fails to compile for such a NodeT.
template<class NodeT, class = void>
struct AcMachineNodeTraits {};

// Detection specialization (std::void_t idiom): picked when NodeT exposes
// id_type, index_type, char_type and edge_type, and re-exports all four.
template<class NodeT>
struct AcMachineNodeTraits<NodeT,
    std::void_t<typename NodeT::id_type,
                typename NodeT::index_type,
                typename NodeT::char_type,
                typename NodeT::edge_type>> {
  using edge_type = typename NodeT::edge_type;
  using char_type = typename NodeT::char_type;
  using index_type = typename NodeT::index_type;
  using id_type = typename NodeT::id_type;
};

// Trie node for the static Aho-Corasick machine.
struct _AcMachineNode {
  using id_type = int;
  using index_type = int;
  using char_type = char;
  using edge_type = std::unordered_map<char_type, index_type>;

  // Index of the key registered at this node, or -1 when none ends here.
  id_type id{-1};
  // Goto transitions: character -> child node index.
  edge_type e{};
  // Index of the fail-link target; -1 until it has been assigned.
  index_type fail{-1};

  // Only the target index is stored. The node's own index and the target
  // node are accepted so that derived node types can share this call shape.
  void set_fail(index_type /*u*/, index_type f, _AcMachineNode& /*r*/) {
    fail = f;
  }
};

template<typename T, class Node>
class _AcMachine {
 public:
  using key_type = std::string;
  using mapped_type = T;
  using value_type = std::pair<key_type, mapped_type>;
  using node_traits = AcMachineNodeTraits<Node>;
  using id_type = typename node_traits::id_type;
  using index_type = typename node_traits::index_type;
  using char_type = typename node_traits::char_type;
 protected:
  std::vector<value_type> container_;
  std::vector<Node> nodes_;

 protected:
  index_type _go(index_type u, char_type c) const {
    auto& node = nodes_[u];
    auto it = node.e.find(c);
    if (it != node.e.end())
      return it->second;
    else
      return u == 0 ? 0 : -1;
  }
  void _insert_key(const std::string& key, const mapped_type& value) {
    if (nodes_.empty())
      nodes_.emplace_back();
    index_type u = 0;
    size_t i;
    for (i = 0; i < key.size(); i++) {
      auto it = nodes_[u].e.find(key[i]);
      if (it == nodes_[u].e.end())
        break;
      u = it->second;
    }
    if (i == key.size())
      return;
    for (; i < key.size(); i++) {
      auto next = nodes_.size();
      nodes_.emplace_back();
      nodes_[u].e[key[i]] = next;
      u = next;
    }
    nodes_[u].id = (id_type) container_.size();
    container_.emplace_back(key, value);
  }
 private:
  template<typename It>
  void _build_trie(It begin, It end) {
    for (auto it = begin; it != end; ++it) {
      _insert_key(*it);
    }
  }
  void _build_fail() {
    std::queue<index_type> idx;
    for (auto s : nodes_[0].e) {
      index_type next = s.second;
      idx.push(next);
      nodes_[next].set_fail(next, 0, nodes_[0]);
    }
    while (!idx.empty()) {
      auto id = idx.front(); idx.pop();
      for (auto s : nodes_[id].e) {
        char_type c = s.first;
        index_type next = s.second;
        idx.push(next);
        auto state = nodes_[id].fail;
        index_type target;
        while ((target = _go(state, c)) == -1)
          state = nodes_[state].fail;
        nodes_[next].set_fail(next, target, nodes_[target]);
      }
    }
  }
  template<typename It>
  void _build(It begin, It end) {
    using traits = std::iterator_traits<It>;
    static_assert(std::is_convertible<typename traits::value_type, std::string>::value, "");
    static_assert(std::is_base_of<std::forward_iterator_tag, typename traits::iterator_category>::value, "");

    _build_trie(begin, end);
    _build_fail();
  }

  std::vector<id_type> _output_id(index_type u) const {
    std::vector<id_type> ret;
    while (u != 0) {
      if (nodes_[u].id != -1)
        ret.push_back(nodes_[u].id);
      u = nodes_[u].fail;
    }
    return ret;
  }

  std::vector<value_type> _output(index_type u) const {
    std::vector<value_type> ret;
    while (u != 0) {
      auto id = nodes_[u].id;
      if (id != -1)
        ret.push_back(container_[id]);
      u = nodes_[u].fail;
    }
    return ret;
  }

 public:
  _AcMachine() = default;

  void insert(const key_type& key, const mapped_type& value) {
    _insert_key(key, value);
  }
  void build() {
    _build_fail();
  }

  std::vector<std::vector<value_type>> all_output() const {
    std::vector<std::vector<value_type>> res(nodes_.size());
    std::queue<index_type> idx;
    idx.push(0);
    while (!idx.empty()) {
      auto u = idx.front(); idx.pop();
      res[u] = _output(u);
      for (auto s : nodes_[u].e) {
        idx.push(s.second);
      }
    }
    return res;
  }

  struct key_iterator {
   public:
    using value_type = _AcMachine::value_type;
    using reference = const value_type&;
    using pointer = const value_type*;
    using iterator_category = std::forward_iterator_tag;
    using difference_type = long long;
    const _AcMachine* ac_;
    index_type u_;
    void _forward_until_data() {
      while (u_ != 0 and ac_->nodes_[u_].id == -1) {
        u_ = ac_->nodes_[u_].fail;
      }
    }
    key_iterator& to_exact() {
      _forward_until_data();
      return *this;
    }
    key_iterator() = default;
    explicit key_iterator(const _AcMachine* ac, index_type u) : ac_(ac), u_(u) {}
    reference operator*() const { return ac_->container_[ac_->nodes_[u_].id]; }
    pointer operator->() const { return &ac_->container_[ac_->nodes_[u_].id]; }
    bool operator==(const key_iterator& r) const { return u_ == r.u_; }
    bool operator!=(const key_iterator& r) const { return !(*this == r); }
    key_iterator& operator++() {
      u_ = ac_->nodes_[u_].fail;
      to_exact();
      return *this;
    }
    key_iterator operator++(int) {
      key_iterator ret = *this;
      ++*this;
      return ret;
    }
    key_iterator& push(char_type c) {
      index_type target;
      while ((target = ac_->_go(u_, c)) == -1)
        u_ = ac_->nodes_[u_].fail;
      u_ = target;
      return *this;
    }
    key_iterator pushed(char_type c) const {
      return key_iterator(*this).push(c);
    }
  };
  key_iterator key_begin() const {
    return key_iterator(this, 0);
  }
  key_iterator key_end() const {
    return key_iterator(this, 0);
  }
  std::vector<std::pair<size_t, key_iterator>> find_all(const std::string& text) const {
    std::vector<std::pair<size_t, key_iterator>> ret;
    auto it = key_begin();
    for (size_t i = 0; i < text.size(); i++) {
      it.push(text[i]);
      auto c = it;
      c.to_exact();
      if (c != key_end()) {
        ret.emplace_back(i+1, c);
      }
    }
    return ret;
  }
  std::pair<std::pair<size_t, key_iterator>, bool> find(const std::string& text) const {
    auto it = key_begin();
    for (size_t i = 0; i < text.size(); i++) {
      it.push(text[i]);
      auto c = it;
      c.to_exact();
      if (c != key_end()) {
        return {{i+1, c}, true};
      }
    }
    return {{}, false};
  }
  key_iterator find_suffix(const std::string& text) const {
    auto it = key_begin();
    for (auto c : text) {
      it.push(c);
    }
    return it.to_exact();
  }
};
// Static Aho-Corasick machine using the plain node type: insert() all keys,
// then call build() once before searching.
template<class T>
using AcMachine = _AcMachine<T, _AcMachineNode>;
#line 4 "include/mtl/string/dynamic_ac_machine.hpp"
#include <unordered_set>

// Node for DynamicAcMachine: additionally keeps reverse fail links.
struct _DynamicAcMachineNode : public _AcMachineNode {
  using link_type = std::unordered_set<_AcMachineNode::index_type>;
  // Indices of the nodes whose `fail` was set to this node.
  link_type ifail;
  // Sets this node's (index u) fail target to f and records u in the target
  // node r's reverse set. The previous target's reverse set is NOT updated
  // here; callers that re-point a fail link must erase u from it themselves.
  void set_fail(index_type u, index_type f, _DynamicAcMachineNode& r) {
    r.ifail.insert(u);
    fail = f;
  }
};

// Aho-Corasick machine that keeps fail links valid across insert() and
// erase(), using the reverse links (ifail) stored in _DynamicAcMachineNode.
template<typename T>
class DynamicAcMachine : public _AcMachine<T, _DynamicAcMachineNode> {
  using super = _AcMachine<T, _DynamicAcMachineNode>;
 public:
  using key_type = typename super::key_type;
  using mapped_type = typename super::mapped_type;
  using value_type = typename super::value_type;
  using id_type = typename super::id_type;
  using index_type = typename super::index_type;
  using char_type = typename super::char_type;

 public:
  DynamicAcMachine() : super() {}

  // Registers (key, value) and incrementally repairs fail links. The empty
  // key is ignored, and an already registered key keeps its original value.
  void insert(const std::string& key, const mapped_type& value) {
    auto& nodes = super::nodes_;
    if (nodes.empty()) {
      nodes.emplace_back();
    }
    index_type u = 0;
    // s[k] = node reached after reading key[0..k].
    std::vector<index_type> s(key.size());
    // Traverse the existing part of the trie.
    size_t i = 0;
    for (; i < key.size(); i++) {
      auto it = nodes[u].e.find(key[i]);
      if (it == nodes[u].e.end())
        break;
      u = it->second;
      s[i] = u;
    }
    if (i == key.size()) {
      // The whole path already exists (key is a prefix of another key, or a
      // duplicate). No node or fail link changes: only register the output.
      if (u != 0 && nodes[u].id == -1) {
        nodes[u].id = (id_type) super::container_.size();
        super::container_.emplace_back(key, value);
      }
      return;
    }
    // Add new nodes for key[d..].
    const size_t d = i;
    for (; i < key.size(); i++) {
      index_type next = (index_type) nodes.size();
      nodes.emplace_back();
      nodes[u].e[key[i]] = next;
      u = next;
      s[i] = u;
    }
    nodes[u].id = (id_type) super::container_.size();
    super::container_.emplace_back(key, value);
    // Set fail of the new nodes.
    for (size_t k = d; k < key.size(); k++) {
      if (k == 0) {
        nodes[s[k]].set_fail(s[k], 0, nodes[0]);
      } else {
        auto state = nodes[s[k-1]].fail;
        index_type next;
        while ((next = super::_go(state, key[k])) == -1)
          state = nodes[state].fail;
        nodes[s[k]].set_fail(s[k], next, nodes[next]);
      }
    }
    // Re-point existing nodes whose fail link must now reach a new node.
    std::queue<std::pair<index_type, size_t>> qs;
    for (auto w : nodes[d == 0 ? 0 : s[d-1]].ifail)
      qs.emplace(w, d);
    while (!qs.empty()) {
      auto [v, k] = qs.front(); qs.pop();
      index_type next;
      while (k < key.size() and (next = super::_go(v, key[k])) != -1) {
        // Unlink `next` from the reverse set of its previous fail target. This
        // must modify the stored node, not a copy, or a stale entry remains.
        const index_type old_fail = nodes[next].fail;
        if (old_fail != -1)
          nodes[old_fail].ifail.erase(next);
        nodes[next].set_fail(next, s[k], nodes[s[k]]);
        v = next;
        ++k;
      }
      if (k < key.size()) {
        for (auto x : nodes[v].ifail)
          qs.emplace(x, k);
      }
    }
  }

  // Unregisters `key`. If its end node becomes a leaf, the tail of the path
  // that no longer belongs to any key is unlinked from the trie and every fail
  // link into it is redirected. Unknown keys and the empty key are ignored.
  // Slots in nodes_ / container_ are not reclaimed.
  void erase(const std::string& key) {
    auto& nodes = super::nodes_;
    if (key.empty() || nodes.empty())
      return;
    index_type u = 0;
    std::vector<index_type> s(key.size());
    for (size_t i = 0; i < key.size(); i++) {
      auto it = nodes[u].e.find(key[i]);
      if (it == nodes[u].e.end()) {
        // key is not contained on the trie.
        return;
      }
      u = it->second;
      s[i] = u;
    }

    if (nodes[u].id == -1) {
      // The path exists but key is not registered.
      return;
    }
    nodes[u].id = -1;
    if (!nodes[u].e.empty()) {
      // Still an inner node of longer keys: keep the path.
      return;
    }
    // s[d..] is the tail that no longer lies on any registered key.
    size_t d = key.size() - 1;
    while (d > 0 and nodes[s[d-1]].id == -1 and nodes[s[d-1]].e.size() == 1)
      d--;
    // Processed shallow to deep, so a removed node whose fail pointed into
    // the tail has already been re-pointed before it is handled itself.
    for (size_t k = d; k < key.size(); k++) {
      const index_type w = s[k];
      const index_type target = nodes[w].fail;
      nodes[target].ifail.erase(w);
      for (auto v : nodes[w].ifail)
        nodes[v].set_fail(v, target, nodes[target]);
      nodes[w].ifail.clear();
    }
    nodes[d == 0 ? 0 : s[d-1]].e.erase(key[d]);
  }
};
#line 3 "test/string/dynamic_ac_machine_test.cpp"

#line 8 "test/string/dynamic_ac_machine_test.cpp"

// Compares two per-node output tables and returns true when they DIFFER.
// A row missing on one side is treated as an empty row, so tables that only
// differ by trailing empty rows compare equal (returns false).
bool compare_output(const std::vector<std::vector<std::string>>& a,
                    const std::vector<std::vector<std::string>>& b) {
  const std::vector<std::string> empty;
  const size_t rows = std::max(a.size(), b.size());
  for (size_t i = 0; i < rows; i++) {
    const auto& l = (i < a.size()) ? a[i] : empty;
    const auto& r = (i < b.size()) ? b[i] : empty;
    if (l != r) {
      return true;
    }
  }
  return false;
}

int main() {
  std::vector<std::string> keys{"ab", "bc", "bab", "d", "abcde"};
  std::vector<std::vector<std::string>> outputs{
      {},
      {},
      {"ab"},
      {},
      {"bc"},
      {},
      {"bab", "ab"},
      {"d"},
      {"bc"},
      {"d"},
      {"abcde"},
  };
  std::vector<std::vector<std::string>> outputs2{
      {},
      {},
      {"ab"},
      {},
      {"bc"},
      {},
      {"bab", "ab"},
      {"d"},
      {"bc"},
      {"d"},
      {"abcde", "bcde"},
      {"d"},
      {"bcde"}
  };

  DynamicAcMachine<int> ac;
  int id = 0;
  for (auto& key : keys) ac.insert(key, id++);
  auto _ret = ac.all_output();
  std::vector<std::vector<std::string>> ret;
  for (auto& l : _ret) {
    ret.emplace_back();
    for (auto& [k,v] : l) ret.back().push_back(k);
  }
  auto v = compare_output(outputs, ret);
  if (compare_output(outputs, ret)) {
    assert(false);
    exit(EXIT_FAILURE);
  }

  ac.insert("bcde", id++);
  _ret = ac.all_output();
  ret = {};
  for (auto& l : _ret) {
    ret.emplace_back();
    for (auto& [k,v] : l) ret.back().push_back(k);
  }
  if (compare_output(outputs2, ret)) {
    assert(false);
    exit(EXIT_FAILURE);
  }

  ac.erase("bcde");
  _ret = ac.all_output();
  ret = {};
  for (auto& l : _ret) {
    ret.emplace_back();
    for (auto& [k,v] : l) ret.back().push_back(k);
  }
//  for (auto o : ret) {
//    std::cout<<"{ ";
//    for (auto k : o)
//      std::cout<<k<<' ';
//    std::cout<<"}"<<std::endl;
//  }
  if (compare_output(outputs, ret)) {
    assert(false);
    exit(EXIT_FAILURE);
  }

  std::cout << "OK" << std::endl;
}
Back to top page