matsutaku-library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub MatsuTaku/matsutaku-library

:heavy_check_mark: test/yosupo/dynamic_tree_vertex_set_path_composite.test.cpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/dynamic_tree_vertex_set_path_composite"
#include "include/mtl/link_cut_tree.hpp"
#include "include/mtl/modular.hpp"
#include <bits/stdc++.h>
using namespace std;

using mint = Modular998244353;
struct Fn {
    mint a,b;
    Fn(mint a=1, mint b=0) : a(a), b(b) {}
    Fn(pair<mint,mint> p) : a(p.first), b(p.second) {}
    Fn operator*(const Fn& r) const {
        return {a*r.a, b*r.a + r.b};
    }
    mint eval(int x) const {
        return a * x + b;
    }
};
using LCT = LinkCutTree<Fn>;

int main() {
    int n,q; cin>>n>>q;
    LCT lct(n);
    for (int i = 0; i < n; i++) {
        int a,b; cin>>a>>b;
        lct.set(i,a,b);
    }
    for (int i = 0; i < n-1; i++) {
        int u,v; cin>>u>>v;
        lct.link(u,v);
    }
    while (q--) {
        int t; cin>>t;
        if (t == 0) {
            int u,v,w,x; cin>>u>>v>>w>>x;
            lct.cut(u,v);
            lct.link(w,x);
        } else if (t == 1) {
            int p,c,d; cin>>p>>c>>d;
            lct.set(p,c,d);
        } else {
            int u,v,x; cin>>u>>v>>x;
            cout << lct.prod(u,v).eval(x) << endl; 
        }
    }
}
#line 1 "test/yosupo/dynamic_tree_vertex_set_path_composite.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/dynamic_tree_vertex_set_path_composite"
#line 2 "include/mtl/splay_tree.hpp"
#include <memory>
#include <cassert>

/// CRTP base of a splay-tree node: owns its children and refers to its parent weakly.
///
/// `p` is either the splay-tree parent (then p->l or p->r is this node) or a
/// link that the parent does not point back to (a path-parent link in a
/// link-cut tree); is_root() distinguishes the two.
template<class NodeType>
struct SplayTreeNodeBase {
    using node_type = NodeType;
    using node_shared = std::shared_ptr<node_type>;
    using node_weak = std::weak_ptr<node_type>;
    node_shared l, r;   // children (owning)
    node_weak p;        // parent link (non-owning, so parent<->child is not a cycle)
    bool rev = false;   // pending "swap the children" flag; false for every new node
    /// True if there is no parent, or the parent holds neither child pointer to this node.
    bool is_root() const {
        const auto parent = p.lock();
        return !parent || (parent->l.get() != this && parent->r.get() != this);
    }
};
/// Re-exports the node type and its owning / non-owning pointer aliases
/// declared by a node type T (e.g. one derived from SplayTreeNodeBase).
template<class T>
struct SplayTreeNodeTraits {
    using node_weak = typename T::node_weak;      // non-owning handle
    using node_shared = typename T::node_shared;  // owning handle
    using node_type = typename T::node_type;      // the node class itself
};

/// Splay-tree rotations and splaying for nodes shaped like SplayTreeNodeBase
/// (l, r, p, is_root()).
///
/// The three virtual hooks are no-ops here. Derived classes (e.g. LinkCutTree)
/// override them: propagate pushes a node's pending lazy data to its children,
/// aggregate recomputes a node's cached subtree data, and reverse_prod fixes
/// the cached data of a subtree that was just marked reversed.
template<class Node>
struct SplayTreeBase {
    using node_traits = SplayTreeNodeTraits<Node>;
    using node_type = typename node_traits::node_type;
    using node_shared = typename node_traits::node_shared;
    using node_weak = typename node_traits::node_weak;
    SplayTreeBase() = default;

    /// Rotates u, the right child of its parent p, above p.
    /// u's left subtree becomes p's right subtree and p becomes u's left child.
    /// If p's old parent q held p as a child, q now holds u. Otherwise (q only
    /// referenced through p's parent link) only u->p is set to q.
    /// Calls no hooks; splay() propagates/aggregates around the rotation.
    void rotate_left(const node_shared& u) const {
        auto p = u->p.lock(), q = p->p.lock();
        p->r = u->l;
        if (p->r)
            p->r->p = p;
        u->l = p;
        p->p = u;
        u->p = q;
        if (q) {
            // If q does not hold p as a child, neither branch fires and the
            // link stays a one-way parent link.
            if (q->l == p)
                q->l = u;
            else if (q->r == p)
                q->r = u;
        }
    }
    /// Mirror image of rotate_left: rotates u, the left child of its parent, above it.
    void rotate_right(const node_shared& u) const {
        auto p = u->p.lock(), q = p->p.lock();
        p->l = u->r;
        if (p->l)
            p->l->p = p;
        u->r = p;
        p->p = u;
        u->p = q;
        if (q) {
            if (q->l == p)
                q->l = u;
            else if (q->r == p)
                q->r = u;
        }
    }
    // Hooks for derived classes; the defaults do nothing.
    virtual void reverse_prod(const node_shared& u) const {}
    virtual void propagate(const node_shared& u) const {}
    virtual void aggregate(const node_shared& u) const {}
    /// Rotates u upward until u->is_root() holds.
    /// Before each step the nodes involved are propagated top-down (grandparent,
    /// parent, u). After the rotation(s) they are aggregated bottom-up, ending with u.
    /// If u is already a root, it is only propagated and aggregated.
    /// Throws the string literal "invalid tree" when a parent link is not
    /// matched by the expected child pointer.
    void splay(const node_shared& u) const {
        if (u->is_root()) {
            this->propagate(u);
            this->aggregate(u);
            return;
        }
        do {
            assert(u);
            auto p = u->p.lock();
            assert(p);
            if (p->is_root()) {
                // zig: p is the splay root, a single rotation finishes.
                this->propagate(p);
                this->propagate(u);
                if (p->l == u)
                    rotate_right(u);
                else if (p->r == u)
                    rotate_left(u);
                else throw "invalid tree";
                // p is now below u, so it is aggregated first.
                this->aggregate(p);
                this->aggregate(u);
            } else {
                auto q = p->p.lock();
                // Push lazy data top-down so child pointers are final before rotating.
                this->propagate(q);
                this->propagate(p);
                this->propagate(u);
                if (q->l == p) {
                    if (p->l == u) { // zig-zig
                        rotate_right(p);
                        rotate_right(u);
                        // Resulting chain: u -> p -> q, so aggregate q then p.
                        this->aggregate(q);
                        this->aggregate(p);
                    } else if (p->r == u) { // zig-zag
                        rotate_left(u);
                        rotate_right(u);
                        // p and q are now both children of u.
                        this->aggregate(p);
                        this->aggregate(q);
                    } else throw "invalid tree";
                } else if (q->r == p) { 
                    if (p->r == u) { // zig-zig
                        rotate_left(p);
                        rotate_left(u);
                        this->aggregate(q);
                        this->aggregate(p);
                    } else if (p->l == u) { // zig-zag
                        rotate_right(u);
                        rotate_left(u);
                        this->aggregate(p);
                        this->aggregate(q);
                    } else throw "invalid tree";
                }
                this->aggregate(u);
            }
        } while (!u->is_root());
    }

};
#line 2 "include/mtl/monoid.hpp"
#include <type_traits>
#include <utility>
#if __cpp_concepts >= 202002L
#include <concepts>
#endif

/// Monoid over T with binary operation `op` and identity `e()`.
/// `a * b` is Monoid(op(a.x, b.x)); the wrapped value is `x` / `val()`.
template<class T, T (*op)(T, T), T (*e)()>
struct Monoid {
  T x;
  /// Identity element.
  Monoid() : x(e()) {}
  /// Constructs the wrapped value in place from args.
  /// Enabled only when T is constructible from args. Without this, copying a
  /// non-const Monoid lvalue (e.g. `Monoid b = a;`) would select this template
  /// over the copy constructor and try to build a T from a Monoid.
  template<class... Args,
           class = std::enable_if_t<std::is_constructible<T, Args&&...>::value>>
  Monoid(Args&&... args) : x(std::forward<Args>(args)...) {}
  Monoid operator*(const Monoid& rhs) const {
    return Monoid(op(x, rhs.x));
  }
  const T& val() const {
    return x;
  }
};

/// Monoid with a single element: the product of any two values is that element.
struct VoidMonoid {
  VoidMonoid() {}
  VoidMonoid operator*(const VoidMonoid&) const {
    return {};
  }
};

#if __cpp_concepts >= 202002L
/// Satisfied when multiplying two T lvalues yields exactly T.
template<class T>
concept IsMonoid = requires (T lhs, T rhs) {
  { lhs * rhs } -> std::same_as<T>;
};
#endif

/// Monoid whose operation is also exposed as `+`.
/// NOTE(review): commutativity of `op` is assumed, not checked.
template<class T, T (*op)(T, T), T (*e)()>
struct CommutativeMonoid : public Monoid<T, op, e> {
    using Base = Monoid<T, op, e>;
    CommutativeMonoid(T x=e()) : Base(x) {}
    /// Returns CommutativeMonoid(op(x, rhs.x)).
    /// The previous `CommutativeMonoid(*this * rhs)` tried to build this type
    /// from the plain Monoid returned by Base::operator*, and no such
    /// constructor exists, so it did not compile.
    CommutativeMonoid operator+(const CommutativeMonoid& rhs) const {
        return CommutativeMonoid(op(this->x, rhs.x));
    }
};

#if __cpp_concepts >= 202002L
/// Satisfied when adding two T lvalues yields exactly T.
template<class T>
concept IsCommutativeMonoid = requires (T lhs, T rhs) {
  { lhs + rhs } -> std::same_as<T>;
};
#endif

/// Operator monoid acting on values of type S.
///   f            - wrapped operator of type F (identity is id()),
///   act(s)       - mapping(f, s),
///   `a *= b`     - stores composition(b.f, a.f) into a.
template<class S, class F, S (*mapping)(F, S), F (*composition)(F, F), F (*id)()>
struct OperatorMonoid {
    F f;
    /// Identity operator.
    OperatorMonoid() : f(id()) {}
    /// Constructs the wrapped operator in place from args.
    /// Enabled only when F is constructible from args, so copying a non-const
    /// OperatorMonoid lvalue uses the copy constructor instead of this template.
    template<class... Args,
             class = std::enable_if_t<std::is_constructible<F, Args&&...>::value>>
    OperatorMonoid(Args&&... args) : f(std::forward<Args>(args)...) {}
    /// Replaces f with composition(rhs.f, f).
    OperatorMonoid& operator*=(const OperatorMonoid& rhs) {
        f = composition(rhs.f, f);
        return *this;
    }
    /// Returns mapping(f, s).
    S act(const S& s) const {
        return mapping(f, s);
    }
};

/// Operator monoid with no effect: composing is a no-op and act() returns its argument.
struct VoidOperatorMonoid {
    VoidOperatorMonoid() {}
    VoidOperatorMonoid& operator*=(const VoidOperatorMonoid&) {
        return *this;
    }
    template<class T>
    T act(const T& value) const {
        return value;
    }
};

#if __cpp_concepts >= 202002L
/// Satisfied when F composes in place (`*=` returning F&) and F::act maps an S to an S.
template<class F, class S>
concept IsOperatorMonoid = requires (F op, S value) {
    { op *= op } -> std::same_as<F&>;
    { op.act(value) } -> std::same_as<S>;
};
#endif
#line 3 "include/mtl/link_cut_tree.hpp"
#include <vector>
#include <iostream>

/// Link-cut tree primitives on top of SplayTreeBase.
///
/// Each preferred path is a splay tree whose in-order is shallow -> deep (the
/// deeper path is attached as the right child in expose()). A splay-tree root's
/// `p` is a one-way link to the path above it. The hooks (propagate, aggregate,
/// reverse_prod) come from the derived class.
template<class NodeType>
struct LinkCutTreeBase : public SplayTreeBase<NodeType> {
    using Base = SplayTreeBase<NodeType>;
    using node_traits = SplayTreeNodeTraits<NodeType>;
    using node_shared = typename node_traits::node_shared;
    /// Makes the path from x's tree root down to x one splay tree and splays x
    /// to its root. The first iteration drops x's right (deeper) subtree, so
    /// no node deeper than x remains in x's splay tree.
    void expose(const node_shared& x) const {
        node_shared r = nullptr;
        // Climb through parent links. At each path, replace the right (deeper)
        // part with the path built so far. The replaced child keeps its `p`,
        // which turns it into a one-way link.
        for (node_shared p = x; p; p = p->p.lock()) {
            Base::splay(p);
            p->r = r;
            r = p;
            this->aggregate(p);
        }
        Base::splay(x);
    }
    /// Makes v the root of its represented tree by exposing it and lazily
    /// reversing the root-to-v path (rev flag plus reverse_prod).
    void evert(const node_shared& v) const {
        expose(v);
        v->rev ^= true;
        this->reverse_prod(v);
        this->propagate(v);
    }
    /// Detaches c from the part of the tree above it.
    /// NOTE(review): c->l is dereferenced without a check, so c must not be
    /// the root of its represented tree.
    void cut(const node_shared& c) const {
        expose(c);
        auto l = c->l;
        c->l = nullptr;
        l->p.reset();
        this->aggregate(c);
    }
    /// Makes p the parent of c after rerooting c's tree at c.
    /// Assumes c and p are in different trees; this is not checked.
    void link(const node_shared& c, const node_shared& p) const {
        evert(c);
        expose(p);
        p->r = c;
        c->p = p;
        this->aggregate(p);
    }
    /// Debug helper: prints `m.x` of u's splay subtree to std::cerr in raw
    /// in-order, without applying pending lazy data. Requires M to have a member `x`.
    void print_tree(const node_shared& u) const {
        if (!u) return;
        if (u->l and u->l->p.lock() == u) {
            print_tree(u->l);
        }
        std::cerr<<u->m.x<<' ';
        if (u->r and u->r->p.lock() == u) {
            print_tree(u->r);
        }
    }
};

/// Node of LinkCutTree: splay links from SplayTreeNodeBase plus
///   m     - value stored on this vertex,
///   prod  - product of m over this splay subtree in in-order (l, self, r),
///   rprod - the same product in reverse order, swapped in on reversal,
///   f     - operator already applied to this node's m/prod/rprod but still
///           pending for its children.
template<class M, class O>
struct LinkCutTreeNode : SplayTreeNodeBase<LinkCutTreeNode<M, O>> {
    M m, prod, rprod;
    O f;
    using Base = SplayTreeNodeBase<LinkCutTreeNode<M, O>>;
    // NOTE(review): the inherited `rev` flag has no initializer of its own in
    // SplayTreeNodeBase. Its starting value depends on how the node is
    // initialized (e.g. make_shared<LinkCutTreeNode>()); confirm it starts false.
    LinkCutTreeNode() = default;
    // m is built from args; prod and rprod start equal to m (one-node subtree).
    // Base() value-initializes the base. prod(m)/rprod(m) rely on m being
    // declared before prod and rprod.
    template<class... Args>
    LinkCutTreeNode(Args&&... args) 
        : Base(), m(std::forward<Args>(args)...), prod(m), rprod(m), f() {}
};

/// Link-cut tree over vertices 0..n-1. Vertex i stores a value of monoid M.
/// Operators of type O can be applied to one vertex or lazily to a whole path.
/// Path products keep path order (u first, v last), so M may be non-commutative.
///
/// Operations are declared `const` but modify the nodes held in `nodes`.
template<class M, class O=VoidOperatorMonoid>
#if __cpp_concepts >= 202002L
requires IsMonoid<M> && IsOperatorMonoid<O, M>
#endif
struct LinkCutTree : LinkCutTreeBase<LinkCutTreeNode<M, O>> {
    using node_type = LinkCutTreeNode<M, O>;
    using Base = LinkCutTreeBase<LinkCutTreeNode<M, O>>;
    using monoid_type = M;
    using operator_monoid_type = O;
    using node_shared = typename SplayTreeNodeTraits<LinkCutTreeNode<M, O>>::node_shared;
    std::vector<node_shared> nodes;
    /// n isolated vertices, each created by make_shared<node_type>().
    LinkCutTree(size_t n) : Base(), nodes(n) {
        for (size_t i = 0; i < n; ++i)
            nodes[i] = std::make_shared<node_type>();
    }
    /// One isolated vertex per element of [first, last), built from that element.
    /// NOTE(review): std::distance consumes single-pass input iterators before
    /// the loop, so this effectively needs forward iterators.
    template<class InputIt>
    LinkCutTree(InputIt first, InputIt last) : Base(), nodes(std::distance(first, last)) {
        size_t i = 0;
        for (auto it = first; it != last; ++it)
            nodes[i++] = std::make_shared<node_type>(*it);
    }
    /// Called when u's subtree is marked reversed: its forward and backward
    /// products trade places.
    void reverse_prod(const node_shared& u) const override {
        std::swap(u->prod, u->rprod);
    }
    /// Pushes u's pending state to its real children (those whose p is u):
    /// - applies u->f to each child's m, prod, rprod and composes it into the child's f;
    /// - if u->rev, swaps u's children, marks each as reversed (swapping its
    ///   prod/rprod), and clears u->rev.
    /// Finally resets u->f to a default-constructed operator.
    void propagate(const node_shared& u) const override {
        // Decide which pointers are real children before any swap.
        bool cl = u->l and u->l->p.lock() == u;
        bool cr = u->r and u->r->p.lock() == u;
        if (cl) {
            u->l->m = u->f.act(u->l->m);
            u->l->prod = u->f.act(u->l->prod);
            u->l->rprod = u->f.act(u->l->rprod);
            u->l->f *= u->f;
        }
        if (cr) {
            u->r->m = u->f.act(u->r->m);
            u->r->prod = u->f.act(u->r->prod);
            u->r->rprod = u->f.act(u->r->rprod);
            u->r->f *= u->f;
        }
        if (u->rev) {
            std::swap(u->l, u->r);
            // After the swap the old right child sits at l, hence cr here.
            if (cr) {
                u->l->rev ^= true;
                reverse_prod(u->l);
            }
            if (cl) {
                u->r->rev ^= true;
                reverse_prod(u->r);
            }
            u->rev = false;
        }
        u->f = operator_monoid_type();
    }
    /// Recomputes u's cached products from its real children:
    /// prod = l.prod * m * r.prod, rprod = r.rprod * m * l.rprod.
    void aggregate(const node_shared& u) const override {
        u->prod = u->m;
        u->rprod = u->m;
        if (u->l and u->l->p.lock() == u) {
            u->prod = u->l->prod * u->prod;
            u->rprod = u->rprod * u->l->rprod;
        }
        if (u->r and u->r->p.lock() == u) {
            u->prod = u->prod * u->r->prod;
            u->rprod = u->r->rprod * u->rprod;
        }
    }
    /// Removes edge (u, v). After evert(u) and expose(v), v's left subtree
    /// holds the path from u to v's predecessor, and that subtree is detached.
    /// Assumes (u, v) is an existing edge (not checked; v->l must be non-null).
    void cut(size_t u, size_t v) const {
        Base::evert(nodes[u]);
        Base::expose(nodes[v]);
        auto l = nodes[v]->l;
        nodes[v]->l = nullptr;
        l->p.reset();
        this->aggregate(nodes[v]);
    }
    /// Adds edge (u, v), making u the parent of v after rerooting v's tree at v.
    /// Assumes u and v are in different trees (not checked).
    void link(size_t u, size_t v) const {
        Base::link(nodes[v], nodes[u]);
    }
    /// Product of the vertex values along the path u -> v, in path order.
    monoid_type prod(size_t u, size_t v) const {
        Base::evert(nodes[u]);
        Base::expose(nodes[v]);
        return nodes[v]->prod;
    }
    /// Replaces vertex i's value with monoid_type(args...).
    template<class... Args>
    void set(size_t i, Args&&... args) const {
        auto u = nodes[i];
        Base::splay(u);
        u->m = monoid_type(std::forward<Args>(args)...);
        this->aggregate(u);
    }
    /// Applies f to vertex i's value only.
    void update(size_t i, const operator_monoid_type& f) const {
        auto u = nodes[i];
        Base::splay(u);
        u->m = f.act(u->m);
        this->aggregate(u);
    }
    /// Applies f to every vertex on the path u -> v. f is applied eagerly to
    /// the exposed root's m/prod/rprod and left pending in its f for the rest.
    void update(size_t u, size_t v, const operator_monoid_type& f) const {
        Base::evert(nodes[u]);
        auto nv = nodes[v];
        Base::expose(nv);
        nv->m = f.act(nv->m);
        nv->prod = f.act(nv->prod);
        nv->rprod = f.act(nv->rprod);
        nv->f *= f;
        Base::splay(nv);
    }
};
#line 2 "include/mtl/bit_manip.hpp"
#include <cstdint>
#line 4 "include/mtl/bit_manip.hpp"
#if __cplusplus >= 202002L
#ifndef MTL_CPP20
#define MTL_CPP20
#endif
#include <bit>
#endif

namespace bm {

/// Count 1s for each 8 bits: byte k of the result is the number of set bits in byte k of x.
inline constexpr uint64_t popcnt_e8(uint64_t x) {
  x = (x & 0x5555555555555555) + ((x>>1) & 0x5555555555555555);
  x = (x & 0x3333333333333333) + ((x>>2) & 0x3333333333333333);
  x = (x & 0x0F0F0F0F0F0F0F0F) + ((x>>4) & 0x0F0F0F0F0F0F0F0F);
  return x;
}
/// Count 1s
inline constexpr unsigned popcnt(uint64_t x) {
#ifdef MTL_CPP20
  return std::popcount(x);
#else
  // Multiplying by 0x0101... accumulates every byte count into the top byte.
  return (popcnt_e8(x) * 0x0101010101010101) >> 56;
#endif
}
/// Alias to bm::popcnt(x)
constexpr unsigned popcount(uint64_t x) {
  return popcnt(x);
}
/// Count trailing 0s. s.t. *11011000 -> 3 (64 for x == 0)
inline constexpr unsigned ctz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_zero(x);
#else
  // x & -x isolates the lowest set bit; minus 1 leaves exactly the trailing
  // zeros set (all 64 bits when x == 0).
  return popcnt((x & (-x)) - 1);
#endif
}
/// Alias to bm::ctz(x)
constexpr unsigned countr_zero(uint64_t x) {
  return ctz(x);
}
/// Count trailing 1s. s.t. *11011011 -> 2
inline constexpr unsigned cto(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_one(x);
#else
  return ctz(~x);
#endif
}
/// Alias to bm::cto(x)
constexpr unsigned countr_one(uint64_t x) {
  return cto(x);
}
/// Count trailing 0s of a byte (8 for x == 0).
inline constexpr unsigned ctz8(uint8_t x) {
  return x == 0 ? 8 : popcnt_e8((x & (-x)) - 1);
}
/// [00..0](8bit) -> 0, [**..*](not only 0) -> 1
/// Bit k of the result is 1 iff byte k of x is nonzero.
inline constexpr uint8_t summary(uint64_t x) {
  constexpr uint64_t hmask = 0x8080808080808080ull;
  constexpr uint64_t lmask = 0x7F7F7F7F7F7F7F7Full;
  auto a = x & hmask;
  auto b = x & lmask;
  // 0x80 - (low 7 bits) never borrows across bytes. Its high bit is set iff the
  // low 7 bits are 0, so after ~ the high bit marks "low 7 bits nonzero".
  b = hmask - b;
  b = ~b;
  auto c = (a | b) & hmask;
  // Gathers the 8 per-byte flags (bits 7, 15, ..., 63) into bits 56..63.
  c *= 0x0002040810204081ull;
  return uint8_t(c >> 56);
}
/// Extract len bits of x starting at bit `start` (requires start < 64).
inline constexpr uint64_t bextr(uint64_t x, unsigned start, unsigned len) {
  uint64_t mask = len < 64 ? (1ull<<len)-1 : 0xFFFFFFFFFFFFFFFFull;
  return (x >> start) & mask;
}
/// Bit width of a byte (index of the highest set bit + 1); 0 for x == 0.
/// 00101101 -> 00111111 -count_1s-> 6
inline constexpr unsigned log2p1(uint8_t x) {
  if (x & 0x80)
    return 8;
  // For x == 0 the subtraction below borrows through every byte and leaves
  // the top byte's high bit clear, which would wrongly report 1.
  if (x == 0)
    return 0;
  // For 0 < x < 0x80, the high bit of byte k of p ends up clear exactly when
  // x >= 2^k, so counting clear high bits yields the bit width.
  uint64_t p = uint64_t(x) * 0x0101010101010101ull;
  p -= 0x8040201008040201ull;
  p = ~p & 0x8080808080808080ull;
  p = (p >> 7) * 0x0101010101010101ull;
  p >>= 56;
  return p;
}
/// 00101100 -mask_mssb-> 00100000 -to_index-> 5
inline constexpr unsigned mssb8(uint8_t x) {
  assert(x != 0);
  return log2p1(x) - 1;
}
/// 00101100 -mask_lssb-> 00000100 -to_index-> 2
inline constexpr unsigned lssb8(uint8_t x) {
  assert(x != 0);
  return popcnt_e8((x & -x) - 1);
}
/// Count leading 0s. 00001011... -> 4 (64 for x == 0)
inline constexpr unsigned clz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_zero(x);
#else
  if (x == 0)
    return 64;
  // Highest nonzero byte i, then the highest set bit j inside it.
  auto i = mssb8(summary(x));
  auto j = mssb8(bextr(x, 8 * i, 8));
  return 63 - (8 * i + j);
#endif
}
/// Alias to bm::clz(x)
constexpr unsigned countl_zero(uint64_t x) {
  return clz(x);
}
/// Count leading 1s. 11110100... -> 4
inline constexpr unsigned clo(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_one(x);
#else
  return clz(~x);
#endif
}
/// Alias to bm::clo(x)
constexpr unsigned countl_one(uint64_t x) {
  return clo(x);
}

/// Count leading 0s of a byte (8 for x == 0).
inline constexpr unsigned clz8(uint8_t x) {
  return x == 0 ? 8 : 7 - mssb8(x);
}
/// Reverses the order of all 64 bits (swap halves, then quarters, ... then single bits).
inline constexpr uint64_t bit_reverse(uint64_t x) {
  x = ((x & 0x00000000FFFFFFFF) << 32) | ((x & 0xFFFFFFFF00000000) >> 32);
  x = ((x & 0x0000FFFF0000FFFF) << 16) | ((x & 0xFFFF0000FFFF0000) >> 16);
  x = ((x & 0x00FF00FF00FF00FF) << 8) | ((x & 0xFF00FF00FF00FF00) >> 8);
  x = ((x & 0x0F0F0F0F0F0F0F0F) << 4) | ((x & 0xF0F0F0F0F0F0F0F0) >> 4);
  x = ((x & 0x3333333333333333) << 2) | ((x & 0xCCCCCCCCCCCCCCCC) >> 2);
  x = ((x & 0x5555555555555555) << 1) | ((x & 0xAAAAAAAAAAAAAAAA) >> 1);
  return x;
}

/// Check if x is power of 2. 00100000 -> true, 00100001 -> false
constexpr bool has_single_bit(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::has_single_bit(x);
#else
  return x != 0 && (x & (x - 1)) == 0;
#endif
}

/// Bit width needs to represent x. 00110110 -> 6
constexpr int bit_width(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::bit_width(x);
#else
  return 64 - clz(x);
#endif
}

/// Ceil power of 2. 00110110 -> 01000000 (undefined if the result exceeds 2^63)
constexpr uint64_t bit_ceil(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_ceil(x);
#else
  if (x == 0) return 1;
  return 1ull << bit_width(x - 1);
#endif
}

/// Floor power of 2. 00110110 -> 00100000 (0 for x == 0)
constexpr uint64_t bit_floor(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_floor(x);
#else
  if (x == 0) return 0;
  return 1ull << (bit_width(x) - 1);
#endif
}

} // namespace bm
#line 5 "include/mtl/modular.hpp"

/// Integer modulo the compile-time constant MOD; the stored value is in [0, MOD).
/// inv(), operator/ and sqrt() assume MOD is prime.
template <int MOD>
class Modular {
 private:
  unsigned int val_;

 public:
  static constexpr unsigned int mod() { return MOD; }
  /// Reduces a signed integer v into [0, mod()).
  template<class T>
  static constexpr unsigned int safe_mod(T v) {
    auto x = (long long)(v%(long long)mod());
    if (x < 0) x += mod();
    return (unsigned int) x;
  }

  constexpr Modular() : val_(0) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && std::is_unsigned<T>::value
      > * = nullptr>
  constexpr Modular(T v) : val_(v%mod()) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && !std::is_unsigned<T>::value
      > * = nullptr>
  constexpr Modular(T v) : val_(safe_mod(v)) {}

  constexpr unsigned int val() const { return val_; }
  constexpr Modular& operator+=(Modular x) {
    val_ += x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr Modular operator-() const { return {mod() - val_}; }
  constexpr Modular& operator-=(Modular x) {
    val_ += mod() - x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr Modular& operator*=(Modular x) {
    auto v = (long long) val_ * x.val();
    if (v >= mod()) v %= mod();
    val_ = v;
    return *this;
  }
  /// Binary exponentiation; p must be non-negative.
  constexpr Modular pow(long long p) const {
    assert(p >= 0);
    Modular t = 1;
    Modular u = *this;
    while (p) {
      if (p & 1)
        t *= u;
      u *= u;
      p >>= 1;
    }
    return t;
  }
  friend constexpr Modular pow(Modular x, long long p) {
    return x.pow(p);
  }
  /// Fermat inverse x^(MOD-2); returns 0 for 0.
  constexpr Modular inv() const { return pow(mod()-2); }
  constexpr Modular& operator/=(Modular x) { return *this *= x.inv(); }
  constexpr Modular operator+(Modular x) const { return Modular(*this) += x; }
  constexpr Modular operator-(Modular x) const { return Modular(*this) -= x; }
  constexpr Modular operator*(Modular x) const { return Modular(*this) *= x; }
  constexpr Modular operator/(Modular x) const { return Modular(*this) /= x; }
  constexpr Modular& operator++() { return *this += 1; }
  constexpr Modular operator++(int) { Modular c = *this; ++(*this); return c; }
  constexpr Modular& operator--() { return *this -= 1; }
  constexpr Modular operator--(int) { Modular c = *this; --(*this); return c; }

  constexpr bool operator==(Modular x) const { return val() == x.val(); }
  constexpr bool operator!=(Modular x) const { return val() != x.val(); }

  /// Euler's criterion. 0 counts as a square (0 = 0 * 0), matching DynamicModular.
  constexpr bool is_square() const {
    return val() == 0 or pow((mod()-1)/2) == 1;
  }
  /**
   * Return x s.t. x * x = a mod p
   * Throws std::runtime_error("not square") when no such x exists.
   * reference: https://zenn.dev/peria/articles/c6afc72b6b003c
  */
  constexpr Modular sqrt() const {
    if (!is_square()) 
      throw std::runtime_error("not square");
    // 0 and 1 are their own roots. Returning early also skips the
    // non-residue search below, which never terminates when MOD == 2.
    if (val() < 2)
      return *this;
    auto mod_eight = mod() % 8;
    if (mod_eight == 3 || mod_eight == 7) {
      return pow((mod()+1)/4);
    } else if (mod_eight == 5) {
      auto x = pow((mod()+3)/8);
      if (x * x != *this)
        x *= Modular(2).pow((mod()-1)/4);
      return x;
    } else {
      Modular d = 2;
      while (d.is_square())
        d += 1;
      auto t = mod()-1;
      int s = bm::ctz(t);
      t >>= s;
      auto a = pow(t);
      auto D = d.pow(t);
      int m = 0;
      Modular dt = 1;
      Modular du = D;
      for (int i = 0; i < s; i++) {
        if ((a*dt).pow(1u<<(s-1-i)) == -1) {
          m |= 1u << i;
          dt *= du;
        }
        du *= du;
      }
      return pow((t+1)/2) * D.pow(m/2);
    }
  }

  friend std::ostream& operator<<(std::ostream& os, const Modular& x) {
    return os << x.val();
  }
  /// Reads an integer (negative values allowed) and reduces it into [0, mod()).
  /// On extraction failure x is left unchanged.
  friend std::istream& operator>>(std::istream& is, Modular& x) {
    long long v;
    if (is >> v)
      x = Modular(v);
    return is;
  }

};

using Modular998244353 = Modular<998244353>;
using Modular1000000007 = Modular<(int)1e9+7>;

/// Integer modulo a runtime modulus shared by all DynamicModular<Id> values
/// (set with set_mod before use). inv(), operator/ and sqrt() assume the
/// modulus is prime.
template<int Id=0>
class DynamicModular {
 private:
  static unsigned int mod_;
  unsigned int val_;

 public:
  static unsigned int mod() { return mod_; }
  static void set_mod(unsigned int m) { mod_ = m; }
  /// Reduces a signed integer v into [0, mod()).
  template<class T>
  static constexpr unsigned int safe_mod(T v) {
    auto x = (long long)(v%(long long)mod());
    if (x < 0) x += mod();
    return (unsigned int) x;
  }

  constexpr DynamicModular() : val_(0) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && std::is_unsigned<T>::value
      > * = nullptr>
  constexpr DynamicModular(T v) : val_(v%mod()) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && !std::is_unsigned<T>::value
      > * = nullptr>
  constexpr DynamicModular(T v) : val_(safe_mod(v)) {}

  constexpr unsigned int val() const { return val_; }
  constexpr DynamicModular& operator+=(DynamicModular x) {
    val_ += x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr DynamicModular operator-() const { return {mod() - val_}; }
  constexpr DynamicModular& operator-=(DynamicModular x) {
    val_ += mod() - x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr DynamicModular& operator*=(DynamicModular x) {
    auto v = (long long) val_ * x.val();
    if (v >= mod()) v %= mod();
    val_ = v;
    return *this;
  }
  /// Binary exponentiation; p must be non-negative.
  constexpr DynamicModular pow(long long p) const {
    assert(p >= 0);
    DynamicModular t = 1;
    DynamicModular u = *this;
    while (p) {
      if (p & 1)
        t *= u;
      u *= u;
      p >>= 1;
    }
    return t;
  }
  friend constexpr DynamicModular pow(DynamicModular x, long long p) {
    return x.pow(p);
  }
  // TODO: implement when mod is not prime
  constexpr DynamicModular inv() const { return pow(mod()-2); }
  constexpr DynamicModular& operator/=(DynamicModular x) { return *this *= x.inv(); }
  constexpr DynamicModular operator+(DynamicModular x) const { return DynamicModular(*this) += x; }
  constexpr DynamicModular operator-(DynamicModular x) const { return DynamicModular(*this) -= x; }
  constexpr DynamicModular operator*(DynamicModular x) const { return DynamicModular(*this) *= x; }
  constexpr DynamicModular operator/(DynamicModular x) const { return DynamicModular(*this) /= x; }
  constexpr DynamicModular& operator++() { return *this += 1; }
  constexpr DynamicModular operator++(int) { DynamicModular c = *this; ++(*this); return c; }
  constexpr DynamicModular& operator--() { return *this -= 1; }
  constexpr DynamicModular operator--(int) { DynamicModular c = *this; --(*this); return c; }

  constexpr bool operator==(DynamicModular x) const { return val() == x.val(); }
  constexpr bool operator!=(DynamicModular x) const { return val() != x.val(); }

  /// Euler's criterion; 0 counts as a square.
  constexpr bool is_square() const {
    return val() == 0 or pow((mod()-1)/2) == 1;
  }
  /**
   * Return x s.t. x * x = a mod p
   * Throws std::runtime_error("not square") when no such x exists.
   * reference: https://zenn.dev/peria/articles/c6afc72b6b003c
  */
  constexpr DynamicModular sqrt() const {
    // assert mod is prime
    if (!is_square()) 
      throw std::runtime_error("not square");
    if (val() < 2)
      return val();
    auto mod_eight = mod() % 8;
    if (mod_eight == 3 || mod_eight == 7) {
      return pow((mod()+1)/4);
    } else if (mod_eight == 5) {
      auto x = pow((mod()+3)/8);
      if (x * x != *this)
        x *= DynamicModular(2).pow((mod()-1)/4);
      return x;
    } else {
      DynamicModular d = 2;
      while (d.is_square())
        ++d;
      auto t = mod()-1;
      int s = bm::ctz(t);
      t >>= s;
      auto a = pow(t);
      auto D = d.pow(t);
      int m = 0;
      DynamicModular dt = 1;
      DynamicModular du = D;
      for (int i = 0; i < s; i++) {
        if ((a*dt).pow(1u<<(s-1-i)) == -1) {
          m |= 1u << i;
          dt *= du;
        }
        du *= du;
      }
      return pow((t+1)/2) * D.pow(m/2);
    }
  }

  friend std::ostream& operator<<(std::ostream& os, const DynamicModular& x) {
    return os << x.val();
  }
  /// Reads an integer (negative values allowed) and reduces it into [0, mod()).
  /// On extraction failure x is left unchanged.
  friend std::istream& operator>>(std::istream& is, DynamicModular& x) {
    long long v;
    if (is >> v)
      x = DynamicModular(v);
    return is;
  }

};
template<int Id>
unsigned int DynamicModular<Id>::mod_;

#line 264 "include/mtl/modular.hpp"

/// Lazily grown table of modular inverses 1/i for ModInt, shared by all
/// ModularUtil<ModInt> objects.
///
/// Uses inv[i] = -(mod / i) * inv[mod % i], valid for 1 <= i < mod when mod is
/// prime. ModInt::mod() must be a constant expression (see `mod` below).
template<class ModInt>
struct ModularUtil {
  static constexpr int mod = ModInt::mod();
  static struct inv_table {
    std::vector<ModInt> tb{0,1};
    inv_table() : tb({0,1}) {}
  } inv_;
  /// Extends the table so that entries 0..n exist.
  /// Static: it touches only the shared table, so no object is needed.
  static void set_inv(int n) {
    int m = inv_.tb.size();
    if (m > n) return;
    inv_.tb.resize(n+1);
    for (int i = m; i < n+1; i++)
      inv_.tb[i] = -inv_.tb[mod % i] * (mod / i);
  }
  /// Returns a reference to the table entry for 1/i (entry 0 holds 0).
  /// NOTE: a later call that grows the table invalidates returned references.
  static ModInt& inv(int i) {
    assert(i >= 0);
    set_inv(i);
    return inv_.tb[i];
  }
};
template<class ModInt>
typename ModularUtil<ModInt>::inv_table ModularUtil<ModInt>::inv_;

#include <array>

namespace math {

/// Returns x^p mod m by binary exponentiation; usable in constant expressions.
/// Requires p >= 0 (a negative p never reaches 0 under >>=) and m >= 1.
constexpr int mod_pow_constexpr(int x, int p, int m) {
  long long t = 1;
  long long u = x;
  while (p) {
    if (p & 1) {
      t *= u;
      t %= m;
    }
    u *= u;
    u %= m;
    p >>= 1;
  }
  return (int) t;
}

/// Returns a primitive root modulo m, assuming m is prime (not checked).
/// Known values are returned for common NTT moduli. Otherwise the smallest
/// g >= 2 with g^((m-1)/q) != 1 for every prime q dividing m-1 is returned.
/// Returns -1 when no such g is found (including m < 2).
constexpr int primitive_root_constexpr(int m) {
  // m == 1 would make x == 0 below, and ctz(0) == 64 is an invalid shift.
  if (m < 2) return -1;
  if (m == 2) return 1;
  if (m == 167772161) return 3;
  if (m == 469762049) return 3;
  if (m == 754974721) return 11;
  if (m == 880803841) return 26;
  if (m == 998244353) return 3;

  // Distinct prime factors of m-1 (an int has at most 9).
  std::array<int, 20> divs{};
  int cnt = 0;
  int x = m-1;
  if (x % 2 == 0) {
    divs[cnt++] = 2;
    x >>= bm::ctz(x);
  }
  // d <= x / d is d*d <= x without the int overflow of d*d near INT_MAX.
  for (int d = 3; d <= x / d; d += 2) {
    if (x % d == 0) {
      divs[cnt++] = d;
      while (x % d == 0)
        x /= d;
    }
  }
  if (x > 1) divs[cnt++] = x;
  for (int g = 2; g < m; g++) {
    bool ok = true;
    for (int i = 0; i < cnt; i++) {
      if (mod_pow_constexpr(g, (m-1) / divs[i], m) == 1) {
        ok = false;
        break;
      }
    }
    if (ok) return g;
  }
  return -1;
}

/// Compile-time primitive root of m.
template<int m>
constexpr int primitive_root = primitive_root_constexpr(m);

}
#line 4 "test/yosupo/dynamic_tree_vertex_set_path_composite.test.cpp"
#include <bits/stdc++.h>
using namespace std;

using mint = Modular998244353;
struct Fn {
    mint a,b;
    Fn(mint a=1, mint b=0) : a(a), b(b) {}
    Fn(pair<mint,mint> p) : a(p.first), b(p.second) {}
    Fn operator*(const Fn& r) const {
        return {a*r.a, b*r.a + r.b};
    }
    mint eval(int x) const {
        return a * x + b;
    }
};
using LCT = LinkCutTree<Fn>;

int main() {
    int n,q; cin>>n>>q;
    LCT lct(n);
    for (int i = 0; i < n; i++) {
        int a,b; cin>>a>>b;
        lct.set(i,a,b);
    }
    for (int i = 0; i < n-1; i++) {
        int u,v; cin>>u>>v;
        lct.link(u,v);
    }
    while (q--) {
        int t; cin>>t;
        if (t == 0) {
            int u,v,w,x; cin>>u>>v>>w>>x;
            lct.cut(u,v);
            lct.link(w,x);
        } else if (t == 1) {
            int p,c,d; cin>>p>>c>>d;
            lct.set(p,c,d);
        } else {
            int u,v,x; cin>>u>>v>>x;
            cout << lct.prod(u,v).eval(x) << endl; 
        }
    }
}

Test cases

Env Name Status Elapsed Memory
g++ example_00 :heavy_check_mark: AC 6 ms 3 MB
g++ example_01 :heavy_check_mark: AC 5 ms 3 MB
g++ max_random_00 :heavy_check_mark: AC 1492 ms 28 MB
g++ max_random_01 :heavy_check_mark: AC 1435 ms 28 MB
g++ max_random_02 :heavy_check_mark: AC 1331 ms 28 MB
g++ medium_00 :heavy_check_mark: AC 8 ms 3 MB
g++ medium_01 :heavy_check_mark: AC 6 ms 3 MB
g++ medium_02 :heavy_check_mark: AC 7 ms 3 MB
g++ medium_03 :heavy_check_mark: AC 6 ms 3 MB
g++ medium_04 :heavy_check_mark: AC 9 ms 4 MB
g++ random_00 :heavy_check_mark: AC 902 ms 19 MB
g++ random_01 :heavy_check_mark: AC 1000 ms 22 MB
g++ random_02 :heavy_check_mark: AC 656 ms 10 MB
g++ random_03 :heavy_check_mark: AC 587 ms 24 MB
g++ random_04 :heavy_check_mark: AC 452 ms 5 MB
g++ small_00 :heavy_check_mark: AC 5 ms 3 MB
g++ small_01 :heavy_check_mark: AC 5 ms 3 MB
g++ small_02 :heavy_check_mark: AC 5 ms 3 MB
g++ small_03 :heavy_check_mark: AC 5 ms 3 MB
g++ small_04 :heavy_check_mark: AC 5 ms 3 MB
clang++ example_00 :heavy_check_mark: AC 6 ms 3 MB
clang++ example_01 :heavy_check_mark: AC 5 ms 3 MB
clang++ max_random_00 :heavy_check_mark: AC 1411 ms 28 MB
clang++ max_random_01 :heavy_check_mark: AC 1436 ms 28 MB
clang++ max_random_02 :heavy_check_mark: AC 1649 ms 28 MB
clang++ medium_00 :heavy_check_mark: AC 7 ms 4 MB
clang++ medium_01 :heavy_check_mark: AC 6 ms 3 MB
clang++ medium_02 :heavy_check_mark: AC 6 ms 3 MB
clang++ medium_03 :heavy_check_mark: AC 6 ms 3 MB
clang++ medium_04 :heavy_check_mark: AC 8 ms 3 MB
clang++ random_00 :heavy_check_mark: AC 934 ms 19 MB
clang++ random_01 :heavy_check_mark: AC 1034 ms 22 MB
clang++ random_02 :heavy_check_mark: AC 624 ms 10 MB
clang++ random_03 :heavy_check_mark: AC 628 ms 24 MB
clang++ random_04 :heavy_check_mark: AC 361 ms 5 MB
clang++ small_00 :heavy_check_mark: AC 5 ms 3 MB
clang++ small_01 :heavy_check_mark: AC 5 ms 3 MB
clang++ small_02 :heavy_check_mark: AC 5 ms 3 MB
clang++ small_03 :heavy_check_mark: AC 5 ms 3 MB
clang++ small_04 :heavy_check_mark: AC 5 ms 3 MB
Back to top page