matsutaku-library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub MatsuTaku/matsutaku-library

:heavy_check_mark: test/yosupo/vertex_set_path_composite.test.cpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/vertex_set_path_composite"
#include "include/mtl/hld.hpp"
#include "include/mtl/segment_hld.hpp"
#include "include/mtl/modular.hpp"
#include <bits/stdc++.h>
using namespace std;

using mint = Modular<998244353>;

struct Composite {
    mint a, b;
    Composite(mint a=1, mint b=0):a(a),b(b) {}
    Composite(pair<int,int> p):a(p.first),b(p.second) {}
    Composite operator*(const Composite& o) const { 
        return Composite(a*o.a, b*o.a + o.b);
    }
    mint eval(mint x) const {
        return a*x+b;
    }
};

int main() {
    int n,q; cin>>n>>q;
    vector<pair<int,int>> C(n);
    for (int i = 0; i < n; i++) {
        int a,b; cin>>a>>b;
        C[i] = {a,b};
    }
    Hld T(n);
    for (int i = 0; i < n-1; i++) {
        int u,v; cin>>u>>v;
        T.add_edge(u,v);
    } 
    T.build();
    decltype(C) D(n);
    for (int i = 0; i < n; i++) 
        D[T.in[i]] = C[i];
    SegmentHld<Composite> path_sum(T, D.begin(), D.end());
    for (int i = 0; i < q; i++) {
        int t; cin>>t;
        if (t == 0) {
            int p,c,d; cin>>p>>c>>d;
            T.set(p, [&](auto i, auto v) {path_sum.set(i, v);}, Composite(c, d));
        } else {
            int u,v,x; cin>>u>>v>>x;
            auto lq = [&](int l, int r) { 
                return path_sum.query(l, r); 
            };
            auto rq = [&](int l, int r) { 
                return path_sum.reverse_query(l, r); 
            };
            auto ret = T.query(u,v,lq,rq).eval(x);
            cout << ret << endl;
        }
    }
}
#line 1 "test/yosupo/vertex_set_path_composite.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/vertex_set_path_composite"
#line 2 "include/mtl/hld.hpp"
#include <cstddef>
#include <vector>

/// Heavy-light decomposition of a tree with vertices 0..n-1.
///
/// Construct with the vertex count, call add_edge once per edge, then build(root).
/// Afterwards each vertex v has a preorder position in[v], assigned with the heavy
/// child visited first, so every heavy path occupies a contiguous range of
/// positions. Path and subtree operations are expressed as half-open position
/// ranges [l, r) that are handed to user callbacks.
///
/// NOTE(review): dfs_sz / dfs_hld are recursive, so the recursion depth equals the
/// tree height (up to n on a path graph). Confirm the stack is large enough for
/// the intended n.
struct Hld {
  int r,n;  // r: root passed to build(); n: number of vertices
  // Adjacency lists. After build(), edge[v][0] is v's heavy child whenever v has children.
  std::vector<std::vector<int>> edge;
  // size:  subtree size
  // in:    preorder position (heavy child first); rev is its inverse
  // out:   one past the last preorder position of v's subtree
  // head:  topmost vertex of the heavy path containing v
  // par:   parent (-1 for the root)
  // depth: number of light edges between the root and v (not the edge depth)
  // clen:  number of vertices on the heavy path from v downward, v included
  std::vector<int> size, in, out, head, rev, par, depth, clen;
 private:
  // Computes par[] and size[] for v's subtree and reorders edge[v] so that the
  // child with the largest subtree ends up in slot 0. The parameter d is unused.
  void dfs_sz(int v, int p, int d) {
    par[v] = p;
    size[v] = 1;
    // Move the parent out of slot 0 so that slot 0 can hold the heavy child.
    if (!edge[v].empty() and edge[v][0] == p)
      std::swap(edge[v][0], edge[v].back());
    for (auto& t:edge[v]) {
      if (t == p) continue;
      dfs_sz(t, v, d+1);
      size[v] += size[t];
      // Keep the largest child seen so far in slot 0.
      if (size[edge[v][0]] < size[t])
        std::swap(edge[v][0], t);
    }
  }
  // Assigns in/rev/head/depth/clen/out. The heavy child (slot 0) is visited
  // first so each heavy path gets consecutive positions. `times` is the next
  // free preorder position.
  void dfs_hld(int v, int p, int& times) {
    in[v] = times++;
    rev[in[v]] = v;
    clen[v] = 1;
    if (!edge[v].empty() and edge[v][0] != p) {
      // Heavy child: same path, same light-edge depth.
      int t = edge[v][0];
      head[t] = head[v];
      depth[t] = depth[v];
      dfs_hld(t, v, times);
      clen[v] += clen[t];
    }
    for (size_t i = 1; i < edge[v].size(); i++) {
      int t = edge[v][i];
      if (t == p) continue;
      // Light child: starts a new heavy path one light edge deeper.
      head[t] = t;
      depth[t] = depth[v] + 1;
      dfs_hld(t, v, times);
    }
    out[v] = times;
  }

 public:
  Hld(int n) : r(0), n(n), edge(n), size(n), in(n, -1), out(n, -1), head(n, -1), rev(n, -1), par(n, -1), depth(n, -1), clen(n) {}

  /// Adds the undirected edge a - b.
  inline void add_edge(int a, int b) {
    edge[a].push_back(b);
    edge[b].push_back(a);
  }

  /// Decomposes the tree rooted at `root`. Must be called after all add_edge calls.
  void build(int root = 0) {
    r = root;
    dfs_sz(root, -1, 0);
    int t = 0;
    head[root] = root;
    depth[root] = 0;
    dfs_hld(root, -1, t);
  }

  /// Lowest common ancestor of a and b.
  /// The vertex with more light edges above it is lifted one heavy path at a time.
  /// Then both are lifted until they share a head; the one earlier in preorder is
  /// the LCA.
  inline int lca(int a, int b) const {
    if (depth[a] > depth[b]) std::swap(a, b);
    while (depth[a] < depth[b]) {
      b = par[head[b]];
    }
    while (head[a] != head[b]) {
      a = par[head[a]];
      b = par[head[b]];
    }
    return in[a] < in[b] ? a : b;
  }

 private:
  // Folds the values on the path u -> v, in that order.
  // Q(l, r) must return the product of positions [l, r) in increasing order
  // (top-down along a heavy path); RQ(l, r) returns the same range in decreasing
  // order. um collects the u side from u upward, and vm collects the v side down
  // to v. Both start as a default-constructed T, which is assumed to be the
  // identity. With include_lca == false the meeting vertex is excluded.
  template<class Query, class ReverseQuery>
  auto _query(int u, int v, Query Q, ReverseQuery RQ, bool include_lca) const -> decltype(Q(0,0)) {
    using T = decltype(Q(0,0));
    T um, vm;
    auto u_up = [&]() {
      um = um * (T)RQ(in[head[u]], in[u]+1);
      u = par[head[u]];
    };
    auto v_up = [&]() {
      vm = (T)Q(in[head[v]], in[v]+1) * vm;
      v = par[head[v]];
    };
    while (depth[u] > depth[v])
      u_up();
    while (depth[u] < depth[v])
      v_up();
    while (head[u] != head[v]) {
      u_up();
      v_up();
    }
    // Same heavy path now: the upper one of u, v is the LCA.
    if (in[u] < in[v]) {
      int l = include_lca ? in[u] : in[u]+1;
      return um * (T)Q(l, in[v]+1) * vm;
    } else {
      int l = include_lca ? in[v] : in[v]+1;
      return um * (T)RQ(l, in[u]+1) * vm;
    }
  }

 public:
  /// Product of the values on the path u -> v (see _query for the contract of Q / RQ).
  template<class Query, class ReverseQuery>
  auto query(int u, int v, Query Q, ReverseQuery RQ, bool include_lca = true) const -> decltype(Q(0,0)) {
    return _query(u, v, Q, RQ, include_lca);
  }

  /// Query for commutative monoid: Q is used for both directions, which is only
  /// correct when the fold result does not depend on direction.
  template<class Query>
  auto query(int u, int v, Query Q, bool include_lca = true) const -> decltype(Q(0,0)) {
    return _query(u, v, Q, Q, include_lca);
  }

  /// Calls S(in[i], val), i.e. forwards val to vertex i's position.
  template<class Set, class T>
  void set(int i, Set S, T&& val) const {
    S(in[i], std::forward<T>(val));
  }

  /// Calls U(l, r, val) for each half-open position range covering the path
  /// u - v (the LCA is skipped when include_lca is false).
  template<typename Upd, typename T>
  void update(int u, int v, Upd U, const T& val, bool include_lca = true) const {
    if (depth[u] > depth[v]) std::swap(u,v);
    auto up = [&](int& v) {
      U(in[head[v]], in[v]+1, val);
      v = par[head[v]];
    };
    while (depth[u] < depth[v]) {
      up(v);
    }
    while (head[u] != head[v]) {
      up(u);
      up(v);
    }
    if (in[u] > in[v]) std::swap(u,v);
    int l = include_lca ? in[u] : in[u]+1;
    U(l, in[v]+1, val);
  }

public:
  /// Post-order pass from the root. For every child v of u other than the one in
  /// slot 0 (the heavy child), calls A(in[u], S(in[v], in[v]+clen[v])), i.e. adds
  /// the aggregate of v's heavy path to u's position.
  /// NOTE(review): presumably this prepares per-vertex light-subtree aggregates
  /// for subtree_sum; confirm the intended usage.
  template<class Add, class Sum>
  void subtree_build(Add A, Sum S) const {
    dfs_subtree_build(A, S, r);
  }
 private:
  template<class Add, class Sum>
  void dfs_subtree_build(Add A, Sum S, int u) const {
    for (size_t i = 0; i < edge[u].size(); i++) {
      auto v = edge[u][i];
      if (v == par[u]) continue;
      dfs_subtree_build(A, S, v);
      if (i > 0)
        A(in[u], S(in[v], in[v]+clen[v]));
    }
  }
 public:
  /// Returns S over the heavy path starting at r: positions [in[r], in[r]+clen[r]).
  /// NOTE(review): this equals the subtree aggregate only when light-subtree
  /// aggregates have been folded in (subtree_build / subtree_point_add).
  template<class T, class Sum>
  T subtree_sum(int r, Sum S) const {
    return (T)S(in[r], in[r]+clen[r]);
  }
  /// Calls A(in[x], val) for x = u and then repeatedly for x = par[head[x]] up to
  /// the root, i.e. u and every ancestor reached through a light edge.
  template<class T, class Add>
  void subtree_point_add(int u, Add A, const T& val) const {
    while (u != -1) {
      A(in[u], val);
      u = par[head[u]];
    }
  }
};
#line 2 "include/mtl/monoid.hpp"
#include <algorithm>
#include <iterator>
#include <type_traits>
#include <utility>
#if __cpp_concepts >= 202002L
#include <concepts>
#endif

template<class T, T (*op)(T, T), T (*e)()>
struct Monoid {
  T x;
  Monoid() : x(e()) {}
  template<class... Args>
  Monoid(Args&&... args) : x(std::forward<Args>(args)...) {}
  Monoid operator*(const Monoid& rhs) const {
    return Monoid(op(x, rhs.x));
  }
  const T& val() const {
    return x;
  }
};

/// Monoid carrying no data: every product is simply another VoidMonoid.
struct VoidMonoid {
  VoidMonoid() {}
  VoidMonoid operator*(const VoidMonoid&) const {
    return {};
  }
};

#if __cpp_concepts >= 202002L
/// Satisfied by types whose binary operator* yields the same type.
template<class T>
concept IsMonoid = requires (T value) {
  { value * value } -> std::same_as<T>;
};
#endif

/// Monoid intended (by name) for commutative operations; commutativity is not
/// checked. Exposes the product through operator+ as well, so it satisfies
/// IsCommutativeMonoid.
template<class T, T (*op)(T, T), T (*e)()>
struct CommutativeMonoid : public Monoid<T, op, e> {
    using Base = Monoid<T, op, e>;
    CommutativeMonoid(T x=e()) : Base(x) {}
    /// Same value as `*this * rhs`, returned as a CommutativeMonoid.
    CommutativeMonoid operator+(const CommutativeMonoid& rhs) const {
        // Combine the stored values directly. Base::operator* yields a Base,
        // which converts neither to T nor to CommutativeMonoid.
        return CommutativeMonoid(op(this->x, rhs.x));
    }
};

#if __cpp_concepts >= 202002L
/// Satisfied by types whose operator+ yields the same type.
template<class T>
concept IsCommutativeMonoid = requires (T m) {
  { m + m } -> std::same_as<T>;
};
#endif

template<class S, class F, S (*mapping)(F, S), F (*composition)(F, F), F (*id)()>
struct OperatorMonoid {
    F f;
    OperatorMonoid() : f(id()) {}
    template<class... Args>
    OperatorMonoid(Args&&... args) : f(std::forward<Args>(args)...) {}
    OperatorMonoid& operator*=(const OperatorMonoid& rhs) {
        f = composition(rhs.f, f);
        return *this;
    }
    S act(const S& s) const {
        return mapping(f, s);
    }
};

/// Identity operator monoid: composition is a no-op and act returns its argument unchanged.
struct VoidOperatorMonoid {
    VoidOperatorMonoid() {}
    VoidOperatorMonoid& operator*=(const VoidOperatorMonoid&) {
        return *this;
    }
    template<class T>
    T act(const T& value) const {
        return value;
    }
};

#if __cpp_concepts >= 202002L
/// Satisfied when F composes in place via *= and F::act maps an S to an S.
template<class F, class S>
concept IsOperatorMonoid = requires (F op, S state) {
    { op *= op } -> std::same_as<F&>;
    { op.act(state) } -> std::same_as<S>;
};
#endif
#line 5 "include/mtl/segment_hld.hpp"
#include <cassert>

/// Shared storage and construction for the HLD-ordered segment trees below.
///
/// The tree covers positions [0, n), i.e. HLD preorder positions (Hld::in).
/// A node is not split at the midpoint of its range but near the weighted
/// midpoint. Position i, holding vertex u = rev[i], weighs
/// size[u] - size[heavy child of u], which is 1 + the total size of u's light
/// subtrees.
/// NOTE(review): presumably this weighting keeps the descents done by path
/// queries short; confirm against the library's design notes.
/// Requires n >= 1 (the split below asserts on an empty root range).
///
/// Node must provide: monoid_type, int fields l, r, p, lc, rc, size(),
/// set(const monoid_type&) and take(const Node&, const Node&).
template<typename Node>
class SegmentHldBase {
 public:
  using monoid_type = typename Node::monoid_type;
 protected:
  int n_;                    // number of positions
  std::vector<Node> tree_;   // nodes; index 0 is the root, children follow their parent
  std::vector<int> target_;  // target_[pos]: index of the leaf covering position pos
 public:
  /// Builds only the tree shape; leaf values keep their default construction.
  explicit SegmentHldBase(const Hld& tree) : n_(tree.n), target_(n_) {
    // cw[i]: total weight of positions [0, i).
    std::vector<long long> cw(n_+1);
    for (int i = 0; i < n_; i++) {
      int u = tree.rev[i];
      auto w = tree.size[u];
      if (!tree.edge[u].empty() and tree.edge[u][0] != tree.par[u])
        w -= tree.size[tree.edge[u][0]];
      cw[i+1] = cw[i] + w;
    }
    tree_.reserve(n_*2);
    tree_.resize(1);
    tree_[0].l = 0;
    tree_[0].r = n_;
    // Breadth-first expansion: tree_ grows while it is being scanned.
    for (int i = 0; i < (int)tree_.size(); i++) {
      if (tree_[i].size() == 1) {
        target_[tree_[i].l] = i;
        continue;
      }
      auto l = tree_[i].l;
      auto r = tree_[i].r;
      // First prefix boundary strictly past the weighted midpoint of [l, r)...
      auto mid = std::upper_bound(cw.begin()+l, cw.begin()+r, (cw[r]+cw[l]+1)/2);
      assert(cw.begin()+l != mid);
      // ...moved one step left when that leaves the right half no lighter.
      if (*std::prev(mid)-cw[l] > cw[r]-*mid)
        --mid;
      int m = mid-cw.begin();
      if (l < m) {
        tree_[i].lc = tree_.size();
        tree_.emplace_back();
        tree_.back().l = l;
        tree_.back().r = m;
        tree_.back().p = i;
      }
      if (m < r) {
        tree_[i].rc = tree_.size();
        tree_.emplace_back();
        tree_.back().l = m;
        tree_.back().r = r;
        tree_.back().p = i;
      }
    }
  }
  /// Builds the shape, stores the k-th element of [begin, end) at position k,
  /// then computes all internal aggregates bottom-up.
  template<typename InputIt>
  explicit SegmentHldBase(const Hld& tree, InputIt begin, InputIt end) : SegmentHldBase(tree) {
    using iterator_value_type = typename std::iterator_traits<InputIt>::value_type;
    static_assert(std::is_convertible<iterator_value_type, monoid_type>::value,
                  "SegmentHldBase: InputIt value type must be convertible to Monoid");
    int pos = 0;
    for (auto it = begin; it != end; ++it, ++pos) {
      assert(pos < n_ && "SegmentHldBase: input range is longer than the tree");
      tree_[target_[pos]].set(monoid_type(*it));
    }
    // Children always have larger indices than their parent, so a reverse scan
    // combines children before their parents.
    for (int i = (int)tree_.size()-1; i >= 0; i--) {
      if (tree_[i].size() == 1) continue;
      tree_[i].take(tree_[tree_[i].lc], tree_[tree_[i].rc]);
    }
  }
};

/// Node of the HLD segment tree. It keeps both the forward product (m) and the
/// reversed product (rm) of the positions it covers.
template<typename M>
struct SegmentHldNode {
  using monoid_type = M;
  int l;        // first covered position
  int r;        // one past the last covered position
  int p = -1;   // parent index, -1 for the root
  int lc = -1;  // left child index, -1 if none
  int rc = -1;  // right child index, -1 if none
  monoid_type m;   // product in increasing position order
  monoid_type rm;  // product in decreasing position order

  int size() const { return r - l; }

  /// Leaf assignment: both orientations hold the same single value.
  void set(const monoid_type& monoid) {
    rm = monoid;
    m = rm;
  }

  /// Recomputes this node from its two children.
  void take(const SegmentHldNode& lhs, const SegmentHldNode& rhs) {
    const monoid_type forward = lhs.m * rhs.m;
    m = forward;
    const monoid_type backward = rhs.rm * lhs.rm;
    rm = backward;
  }
};
/// Point-update / range-fold segment tree over HLD preorder positions.
/// It stores forward and reversed products so a range can be folded in either
/// direction, as Hld::query needs.
template<
#if __cpp_concepts >= 202002L
  IsMonoid
#else
  class
#endif
    M>
class SegmentHld : private SegmentHldBase<SegmentHldNode<M>> {
 public:
  using monoid_type = M;
 private:
  using Node = SegmentHldNode<M>;
  using Base = SegmentHldBase<Node>;
  using Base::n_;
  using Base::tree_;
  using Base::target_;
 public:
  explicit SegmentHld(const Hld& tree) : Base(tree) {}
  template<typename InputIt>
  explicit SegmentHld(const Hld& tree, InputIt begin, InputIt end) : Base(tree, begin, end) {}

  /// Value stored at position `index`.
  const monoid_type& get(int index) const {
    const Node& leaf = tree_[target_[index]];
    return leaf.m;
  }
  /// Reversed-orientation value stored at position `index`.
  const monoid_type& get_reversed(int index) const {
    const Node& leaf = tree_[target_[index]];
    return leaf.rm;
  }
  /// Replaces position `index` with M(args...) and refreshes every ancestor.
  template<class... Args>
  void set(int index, Args&&... args) {
    int node = target_[index];
    tree_[node].set(M(std::forward<Args>(args)...));
    for (node = tree_[node].p; node != -1; node = tree_[node].p)
      tree_[node].take(tree_[tree_[node].lc], tree_[tree_[node].rc]);
  }
  /// Product of positions [l, r) in increasing order (M() when empty).
  M query(int l, int r) const {
    return _query<false>(l, r, 0);
  }
  /// Product of positions [l, r) in decreasing order (M() when empty).
  M reverse_query(int l, int r) const {
    return _query<true>(l, r, 0);
  }
 private:
  template<bool Reverse>
  M _query(int l, int r, int u) const {
    if (u == -1)
      return M();
    const Node& node = tree_[u];
    const bool disjoint = node.r <= l || r <= node.l;
    if (disjoint)
      return M();
    const bool covered = l <= node.l && node.r <= r;
    if (covered) {
      if constexpr (Reverse)
        return node.rm;
      else
        return node.m;
    }
    // Partial overlap: in reverse mode the right child comes first.
    if constexpr (Reverse)
      return _query<true>(l, r, node.rc) * _query<true>(l, r, node.lc);
    else
      return _query<false>(l, r, node.lc) * _query<false>(l, r, node.rc);
  }
};


/// SegmentHldNode extended with a lazily applied operator `a` (consumed by LazySegmentHld).
template<typename M, typename A>
struct LazySegmentHldNode : public SegmentHldNode<M> {
  using operator_monoid_type = A;
  operator_monoid_type a;
};
/// Segment tree over HLD preorder positions with lazy range operators.
///
/// M is the value type: it needs operator*, and a default-constructed M is used
/// as the empty product. A is the operator type: A() is the pending-nothing
/// state, `*=` composes operators, and act(m) applies an operator to a value.
/// Forward and reversed products are both stored and both transformed with
/// act().
/// NOTE(review): this assumes A's action commutes with reversing a range;
/// confirm for the operator types used.
template<typename M, typename A>
#if __cpp_concepts >= 202002L
requires IsMonoid<M> && IsOperatorMonoid<A, M>
#endif
class LazySegmentHld : private SegmentHldBase<LazySegmentHldNode<M,A>> {
 public:
  using monoid_type = M;
  using operator_monoid_type = A;
 private:
  using Node = LazySegmentHldNode<M,A>;
  using Base = SegmentHldBase<Node>;
  using Base::n_;
  using Base::tree_;
  using Base::target_;
 public:
  explicit LazySegmentHld(const Hld& tree) : Base(tree) {}
  template<typename InputIt>
  explicit LazySegmentHld(const Hld& tree, InputIt begin, InputIt end) : Base(tree, begin, end) {}
 private:
  // Whether a pending operator can be skipped.
  // Operator types that provide operator() (truthy when there is work to do)
  // are asked directly. Types without it, such as OperatorMonoid and
  // VoidOperatorMonoid, are never assumed to be the identity and are always
  // pushed down. Before this, they failed to compile here.
  template<typename Op>
  static auto _is_identity(Op& pending, int) -> decltype(static_cast<bool>(!pending())) {
    return static_cast<bool>(!pending());
  }
  template<typename Op>
  static bool _is_identity(Op&, long) {
    return false;
  }
  // Applies node u's pending operator to its products, pushes it to both
  // children (internal nodes only), and resets it to A().
  inline void _propagate(int u) {
    auto& n = tree_[u];
    auto& a = n.a;
    if (_is_identity(a, 0)) return;
    n.m = a.act(n.m);
    n.rm = a.act(n.rm);
    if (n.size() > 1) {
      tree_[n.lc].a *= a;
      tree_[n.rc].a *= a;
    }
    n.a = A();
  }
 public:
  /// Replaces position `index` with monoid_type(v).
  template<typename T>
  void set(int index, T&& v) {
    // Collect the leaf-to-root path.
    std::vector<int> ids;
    int u = target_[index];
    ids.push_back(u);
    u = tree_[u].p;
    while (u != -1) {
      ids.push_back(u);
      u = tree_[u].p;
    }
    // Push pending operators down from the root to the leaf.
    for (int i = (int)ids.size()-1; i >= 0; i--) {
      _propagate(ids[i]);
    }
    tree_[ids[0]].set(monoid_type(std::forward<T>(v)));
    for (int i = 1; i < (int)ids.size(); i++) {
      u = ids[i];
      auto lc = tree_[u].lc, rc = tree_[u].rc;
      // The child of u that is not on the path (its sibling on the path is ids[i-1]).
      auto ac = lc ^ rc ^ ids[i-1];
      _propagate(ac);
      tree_[u].take(tree_[lc], tree_[rc]);
    }
  }
  /// Product of positions [l, r) in increasing order.
  M query(int l, int r) {
    return _query<0>(l,r,0);
  }
  /// Product of positions [l, r) in decreasing order.
  M reverse_query(int l, int r) {
    return _query<1>(l,r,0);
  }
 private:
  template<bool Reverse>
  M _query(int l, int r, int u) {
    if (u == -1)
      return M();
    auto _l = tree_[u].l, _r = tree_[u].r;
    if (_r <= l or r <= _l)
      return M();
    _propagate(u);
    if (l <= _l and _r <= r) {
      if constexpr (!Reverse)
        return tree_[u].m;
      else
        return tree_[u].rm;
    } else {
      if constexpr (!Reverse)
        return _query<0>(l, r, tree_[u].lc) * _query<0>(l, r, tree_[u].rc);
      else
        return _query<1>(l, r, tree_[u].rc) * _query<1>(l, r, tree_[u].lc);
    }
  }
 public:
  /// Composes operator v onto every position in [l, r).
  template<typename T>
  void update(int l, int r, const T& v) {
    _update(l, r, v, 0);
  }
 private:
  template<typename T>
  void _update(int l, int r, const T& v, int u) {
    if (u == -1)
      return;
    auto _l = tree_[u].l, _r = tree_[u].r;
    if (_r <= l or r <= _l) {
      // Untouched, but the parent's take() below needs up-to-date products.
      _propagate(u);
    } else if (l <= _l and _r <= r) {
      tree_[u].a *= v;
      _propagate(u);
    } else {
      _propagate(u);
      if (tree_[u].size() > 1) {
        auto lc = tree_[u].lc, rc = tree_[u].rc;
        _update(l, r, v, lc);
        _update(l, r, v, rc);
        tree_[u].take(tree_[lc], tree_[rc]);
      }
    }
  }
};
#line 2 "include/mtl/bit_manip.hpp"
#include <cstdint>
#line 4 "include/mtl/bit_manip.hpp"
#if __cplusplus >= 202002L
#ifndef MTL_CPP20
#define MTL_CPP20
#endif
#include <bit>
#endif

namespace bm {

/// Per-byte population count: each 8-bit lane of the result holds the number of
/// 1 bits that lane had in x.
inline constexpr uint64_t popcnt_e8(uint64_t x) {
  x = (x & 0x5555555555555555) + ((x>>1) & 0x5555555555555555);
  x = (x & 0x3333333333333333) + ((x>>2) & 0x3333333333333333);
  x = (x & 0x0F0F0F0F0F0F0F0F) + ((x>>4) & 0x0F0F0F0F0F0F0F0F);
  return x;
}
/// Count 1s
inline constexpr unsigned popcnt(uint64_t x) {
#ifdef MTL_CPP20
  return std::popcount(x);
#else
  // Sum the per-byte counts into the top byte.
  return (popcnt_e8(x) * 0x0101010101010101) >> 56;
#endif
}
/// Alias of bm::popcnt(x)
constexpr unsigned popcount(uint64_t x) {
  return popcnt(x);
}
/// Count trailing 0s, e.g. *11011000 -> 3; returns 64 for x == 0.
inline constexpr unsigned ctz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_zero(x);
#else
  return popcnt((x & (-x)) - 1);
#endif
}
/// Alias of bm::ctz(x)
constexpr unsigned countr_zero(uint64_t x) {
  return ctz(x);
}
/// Count trailing 1s, e.g. *11011011 -> 2
inline constexpr unsigned cto(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_one(x);
#else
  return ctz(~x);
#endif
}
/// Alias of bm::cto(x)
constexpr unsigned countr_one(uint64_t x) {
  return cto(x);
}
/// Count trailing 0s of an 8-bit value; 8 for x == 0.
inline constexpr unsigned ctz8(uint8_t x) {
  return x == 0 ? 8 : popcnt_e8((x & (-x)) - 1);
}
/// One bit per byte: bit i is set iff byte i of x is non-zero.
/// [00..0](8bit) -> 0, [**..*](not only 0) -> 1
inline constexpr uint8_t summary(uint64_t x) {
  constexpr uint64_t hmask = 0x8080808080808080ull;
  constexpr uint64_t lmask = 0x7F7F7F7F7F7F7F7Full;
  auto a = x & hmask;
  auto b = x & lmask;
  // The high bit of each lane of ~(0x80 - low7) is set iff low7 != 0; no borrow crosses lanes.
  b = hmask - b;
  b = ~b;
  auto c = (a | b) & hmask;
  // Gather the eight lane flags into the top byte.
  c *= 0x0002040810204081ull;
  return uint8_t(c >> 56);
}
/// Extract len bits of x starting at bit `start`.
inline constexpr uint64_t bextr(uint64_t x, unsigned start, unsigned len) {
  uint64_t mask = len < 64 ? (1ull<<len)-1 : 0xFFFFFFFFFFFFFFFFull;
  return (x >> start) & mask;
}
/// Number of significant bits of an 8-bit value:
/// 00101101 -> 00111111 -count_1s-> 6, and 0 -> 0.
inline constexpr unsigned log2p1(uint8_t x) {
  if (x == 0)
    return 0;  // the SWAR comparison below would report 1 (a borrow leaves the top lane)
  if (x & 0x80)
    return 8;
  // Lane k computes x - 2^k; its high bit stays clear iff x >= 2^k.
  uint64_t p = uint64_t(x) * 0x0101010101010101ull;
  p -= 0x8040201008040201ull;
  p = ~p & 0x8080808080808080ull;
  p = (p >> 7) * 0x0101010101010101ull;
  p >>= 56;
  return p;
}
/// Index of the most significant set bit; x must be non-zero.
/// 00101100 -mask_mssb-> 00100000 -to_index-> 5
inline constexpr unsigned mssb8(uint8_t x) {
  assert(x != 0);
  return log2p1(x) - 1;
}
/// Index of the least significant set bit; x must be non-zero.
/// 00101100 -mask_lssb-> 00000100 -to_index-> 2
inline constexpr unsigned lssb8(uint8_t x) {
  assert(x != 0);
  return popcnt_e8((x & -x) - 1);
}
/// Count leading 0s, e.g. 00001011... -> 4; returns 64 for x == 0.
inline constexpr unsigned clz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_zero(x);
#else
  if (x == 0)
    return 64;
  auto i = mssb8(summary(x));            // highest non-zero byte
  auto j = mssb8(bextr(x, 8 * i, 8));    // highest set bit inside it
  return 63 - (8 * i + j);
#endif
}
/// Alias of bm::clz(x)
constexpr unsigned countl_zero(uint64_t x) {
  return clz(x);
}
/// Count leading 1s, e.g. 11110100... -> 4
inline constexpr unsigned clo(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_one(x);
#else
  return clz(~x);
#endif
}
/// Alias of bm::clo(x)
constexpr unsigned countl_one(uint64_t x) {
  return clo(x);
}

/// Count leading 0s of an 8-bit value; 8 for x == 0.
inline constexpr unsigned clz8(uint8_t x) {
  return x == 0 ? 8 : 7 - mssb8(x);
}
/// Reverse the order of all 64 bits.
inline constexpr uint64_t bit_reverse(uint64_t x) {
  x = ((x & 0x00000000FFFFFFFF) << 32) | ((x & 0xFFFFFFFF00000000) >> 32);
  x = ((x & 0x0000FFFF0000FFFF) << 16) | ((x & 0xFFFF0000FFFF0000) >> 16);
  x = ((x & 0x00FF00FF00FF00FF) << 8) | ((x & 0xFF00FF00FF00FF00) >> 8);
  x = ((x & 0x0F0F0F0F0F0F0F0F) << 4) | ((x & 0xF0F0F0F0F0F0F0F0) >> 4);
  x = ((x & 0x3333333333333333) << 2) | ((x & 0xCCCCCCCCCCCCCCCC) >> 2);
  x = ((x & 0x5555555555555555) << 1) | ((x & 0xAAAAAAAAAAAAAAAA) >> 1);
  return x;
}

/// Check if x is power of 2. 00100000 -> true, 00100001 -> false
constexpr bool has_single_bit(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::has_single_bit(x);
#else
  return x != 0 && (x & (x - 1)) == 0;
#endif
}

/// Bit width needs to represent x. 00110110 -> 6
constexpr int bit_width(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::bit_width(x);
#else
  return 64 - clz(x);
#endif
}

/// Ceil power of 2. 00110110 -> 01000000
constexpr uint64_t bit_ceil(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_ceil(x);
#else
  if (x == 0) return 1;
  return 1ull << bit_width(x - 1);
#endif
}

/// Floor power of 2. 00110110 -> 00100000
constexpr uint64_t bit_floor(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_floor(x);
#else
  if (x == 0) return 0;
  return 1ull << (bit_width(x) - 1);
#endif
}

} // namespace bm
#line 3 "include/mtl/modular.hpp"
#include <iostream>
#line 5 "include/mtl/modular.hpp"

/// Integer modulo the compile-time constant MOD.
/// Values are kept in [0, MOD). MOD is an `int` template argument, so it is
/// below 2^31, which keeps the unsigned arithmetic below free of overflow.
template <int MOD>
class Modular {
 private:
  unsigned int val_;

 public:
  static constexpr unsigned int mod() { return MOD; }
  /// Reduces a signed value into [0, mod()).
  template<class T>
  static constexpr unsigned int safe_mod(T v) {
    auto x = (long long)(v%(long long)mod());
    if (x < 0) x += mod();
    return (unsigned int) x;
  }

  constexpr Modular() : val_(0) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && std::is_unsigned<T>::value
      > * = nullptr>
  constexpr Modular(T v) : val_(v%mod()) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && !std::is_unsigned<T>::value
      > * = nullptr>
  constexpr Modular(T v) : val_(safe_mod(v)) {}

  constexpr unsigned int val() const { return val_; }
  constexpr Modular& operator+=(Modular x) {
    val_ += x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr Modular operator-() const { return {mod() - val_}; }
  constexpr Modular& operator-=(Modular x) {
    val_ += mod() - x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr Modular& operator*=(Modular x) {
    auto v = (long long) val_ * x.val();
    if (v >= mod()) v %= mod();
    val_ = v;
    return *this;
  }
  /// this^p by binary exponentiation; p must be non-negative.
  constexpr Modular pow(long long p) const {
    assert(p >= 0);
    Modular t = 1;
    Modular u = *this;
    while (p) {
      if (p & 1)
        t *= u;
      u *= u;
      p >>= 1;
    }
    return t;
  }
  friend constexpr Modular pow(Modular x, long long p) {
    return x.pow(p);
  }
  /// Inverse via Fermat's little theorem; assumes mod() is prime and *this != 0.
  constexpr Modular inv() const { return pow(mod()-2); }
  constexpr Modular& operator/=(Modular x) { return *this *= x.inv(); }
  constexpr Modular operator+(Modular x) const { return Modular(*this) += x; }
  constexpr Modular operator-(Modular x) const { return Modular(*this) -= x; }
  constexpr Modular operator*(Modular x) const { return Modular(*this) *= x; }
  constexpr Modular operator/(Modular x) const { return Modular(*this) /= x; }
  constexpr Modular& operator++() { return *this += 1; }
  constexpr Modular operator++(int) { Modular c = *this; ++(*this); return c; }
  constexpr Modular& operator--() { return *this -= 1; }
  constexpr Modular operator--(int) { Modular c = *this; --(*this); return c; }

  constexpr bool operator==(Modular x) const { return val() == x.val(); }
  constexpr bool operator!=(Modular x) const { return val() != x.val(); }

  /// Euler's criterion (mod() prime). 0 counts as a square, matching DynamicModular.
  constexpr bool is_square() const {
    return val() == 0 or pow((mod()-1)/2) == 1;
  }
  /**
   * Return x s.t. x * x = a mod p (mod() assumed prime).
   * Throws std::runtime_error when *this is not a square.
   * reference: https://zenn.dev/peria/articles/c6afc72b6b003c
  */
  constexpr Modular sqrt() const {
    if (!is_square()) 
      throw std::runtime_error("not square");
    if (val() < 2)
      return *this;  // 0 and 1 are their own square roots
    auto mod_eight = mod() % 8;
    if (mod_eight == 3 || mod_eight == 7) {
      return pow((mod()+1)/4);
    } else if (mod_eight == 5) {
      auto x = pow((mod()+3)/8);
      if (x * x != *this)
        x *= Modular(2).pow((mod()-1)/4);
      return x;
    } else {
      // Tonelli-Shanks-style search with a non-residue d.
      Modular d = 2;
      while (d.is_square())
        d += 1;
      auto t = mod()-1;
      int s = bm::ctz(t);
      t >>= s;
      auto a = pow(t);
      auto D = d.pow(t);
      int m = 0;
      Modular dt = 1;
      Modular du = D;
      for (int i = 0; i < s; i++) {
        if ((a*dt).pow(1u<<(s-1-i)) == -1) {
          m |= 1u << i;
          dt *= du;
        }
        du *= du;
      }
      return pow((t+1)/2) * D.pow(m/2);
    }
  }

  friend std::ostream& operator<<(std::ostream& os, const Modular& x) {
    return os << x.val();
  }
  /// Reads an integer and reduces it modulo mod(), so stored values stay in
  /// range even for inputs >= mod() or negative inputs. On failure x is left
  /// unchanged.
  friend std::istream& operator>>(std::istream& is, Modular& x) {
    long long v;
    if (is >> v)
      x = Modular(v);
    return is;
  }

};

using Modular998244353 = Modular<998244353>;
using Modular1000000007 = Modular<(int)1e9+7>;

/// Integer modulo a modulus chosen at runtime via set_mod().
/// Id separates independent moduli. mod_ is zero until set_mod() is called, so
/// set a non-zero modulus before constructing or operating on values. Any
/// modulus up to 2^32 - 1 is handled: sums and products use 64-bit intermediates.
template<int Id=0>
class DynamicModular {
 private:
  static unsigned int mod_;
  unsigned int val_;

 public:
  static unsigned int mod() { return mod_; }
  static void set_mod(unsigned int m) { mod_ = m; }
  /// Reduces a signed value into [0, mod()).
  template<class T>
  static constexpr unsigned int safe_mod(T v) {
    auto x = (long long)(v%(long long)mod());
    if (x < 0) x += mod();
    return (unsigned int) x;
  }

  constexpr DynamicModular() : val_(0) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && std::is_unsigned<T>::value
      > * = nullptr>
  constexpr DynamicModular(T v) : val_(v%mod()) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && !std::is_unsigned<T>::value
      > * = nullptr>
  constexpr DynamicModular(T v) : val_(safe_mod(v)) {}

  constexpr unsigned int val() const { return val_; }
  constexpr DynamicModular& operator+=(DynamicModular x) {
    // 64-bit sum: val_ + x can exceed 2^32 when mod() > 2^31.
    unsigned long long v = (unsigned long long) val_ + x.val();
    if (v >= mod()) v -= mod();
    val_ = (unsigned int) v;
    return *this;
  }
  constexpr DynamicModular operator-() const { return {mod() - val_}; }
  constexpr DynamicModular& operator-=(DynamicModular x) {
    unsigned long long v = (unsigned long long) val_ + mod() - x.val();
    if (v >= mod()) v -= mod();
    val_ = (unsigned int) v;
    return *this;
  }
  constexpr DynamicModular& operator*=(DynamicModular x) {
    // Unsigned 64-bit product: (2^32 - 1)^2 does not fit in long long.
    unsigned long long v = (unsigned long long) val_ * x.val();
    val_ = (unsigned int) (v % mod());
    return *this;
  }
  /// this^p by binary exponentiation; p must be non-negative.
  constexpr DynamicModular pow(long long p) const {
    assert(p >= 0);
    DynamicModular t = 1;
    DynamicModular u = *this;
    while (p) {
      if (p & 1)
        t *= u;
      u *= u;
      p >>= 1;
    }
    return t;
  }
  friend constexpr DynamicModular pow(DynamicModular x, long long p) {
    return x.pow(p);
  }
  // TODO: implement when mod is not prime
  constexpr DynamicModular inv() const { return pow(mod()-2); }
  constexpr DynamicModular& operator/=(DynamicModular x) { return *this *= x.inv(); }
  constexpr DynamicModular operator+(DynamicModular x) const { return DynamicModular(*this) += x; }
  constexpr DynamicModular operator-(DynamicModular x) const { return DynamicModular(*this) -= x; }
  constexpr DynamicModular operator*(DynamicModular x) const { return DynamicModular(*this) *= x; }
  constexpr DynamicModular operator/(DynamicModular x) const { return DynamicModular(*this) /= x; }
  constexpr DynamicModular& operator++() { return *this += 1; }
  constexpr DynamicModular operator++(int) { DynamicModular c = *this; ++(*this); return c; }
  constexpr DynamicModular& operator--() { return *this -= 1; }
  constexpr DynamicModular operator--(int) { DynamicModular c = *this; --(*this); return c; }

  constexpr bool operator==(DynamicModular x) const { return val() == x.val(); }
  constexpr bool operator!=(DynamicModular x) const { return val() != x.val(); }

  /// Euler's criterion (mod() prime); 0 counts as a square.
  constexpr bool is_square() const {
    return val() == 0 or pow((mod()-1)/2) == 1;
  }
  /**
   * Return x s.t. x * x = a mod p
   * reference: https://zenn.dev/peria/articles/c6afc72b6b003c
  */
  constexpr DynamicModular sqrt() const {
    // assert mod is prime
    if (!is_square()) 
      throw std::runtime_error("not square");
    if (val() < 2)
      return val();
    auto mod_eight = mod() % 8;
    if (mod_eight == 3 || mod_eight == 7) {
      return pow((mod()+1)/4);
    } else if (mod_eight == 5) {
      auto x = pow((mod()+3)/8);
      if (x * x != *this)
        x *= DynamicModular(2).pow((mod()-1)/4);
      return x;
    } else {
      DynamicModular d = 2;
      while (d.is_square())
        ++d;
      auto t = mod()-1;
      int s = bm::ctz(t);
      t >>= s;
      auto a = pow(t);
      auto D = d.pow(t);
      int m = 0;
      DynamicModular dt = 1;
      DynamicModular du = D;
      for (int i = 0; i < s; i++) {
        if ((a*dt).pow(1u<<(s-1-i)) == -1) {
          m |= 1u << i;
          dt *= du;
        }
        du *= du;
      }
      return pow((t+1)/2) * D.pow(m/2);
    }
  }

  friend std::ostream& operator<<(std::ostream& os, const DynamicModular& x) {
    return os << x.val();
  }
  /// Reads an integer and reduces it modulo mod(); x is unchanged on failure.
  friend std::istream& operator>>(std::istream& is, DynamicModular& x) {
    long long v;
    if (is >> v)
      x = DynamicModular(v);
    return is;
  }

};
template<int Id>
unsigned int DynamicModular<Id>::mod_;

#line 264 "include/mtl/modular.hpp"

/// Shared, lazily grown table of modular inverses 1/i for the modulus
/// ModInt::mod(), which must be a constant expression and is presumably prime.
/// inv_.tb[i] is the inverse of i, and tb[0] == 0. The table is static, so all
/// users of the same ModInt share it. The functions are static for that reason,
/// but calling them through an instance still works.
template<class ModInt>
struct ModularUtil {
  static constexpr int mod = ModInt::mod();
  static struct inv_table {
    std::vector<ModInt> tb{0,1};
  } inv_;
  /// Extends the table so that indices 0..n are available.
  static void set_inv(int n) {
    int m = inv_.tb.size();
    if (m > n) return;
    inv_.tb.resize(n+1);
    // inv(i) = -(mod / i) * inv(mod % i)  (mod ≡ (mod / i) * i + mod % i)
    for (int i = m; i < n+1; i++)
      inv_.tb[i] = -inv_.tb[mod % i] * (mod / i);
  }
  /// Inverse of i (i >= 0). The returned reference is invalidated when a later
  /// call grows the table.
  static ModInt& inv(int i) {
    assert(i >= 0);
    set_inv(i);
    return inv_.tb[i];
  }
};
template<class ModInt>
typename ModularUtil<ModInt>::inv_table ModularUtil<ModInt>::inv_;

#include <array>

namespace math {

/// x^p mod m, usable in constant expressions.
/// x may be negative or >= m; it is reduced into [0, m) first, so the result
/// is always in [0, m). Requires p >= 0 and m >= 1.
constexpr int mod_pow_constexpr(int x, int p, int m) {
  long long t = 1 % m;  // 0 when m == 1
  long long u = ((long long)x % m + m) % m;
  while (p) {
    if (p & 1) {
      t *= u;
      t %= m;
    }
    u *= u;
    u %= m;
    p >>= 1;
  }
  return (int) t;
}

/// Smallest primitive root modulo m, or -1 if none is found (including m < 2).
/// Known NTT-friendly primes are answered from a table. Otherwise the
/// distinct prime factors of m - 1 are collected, and g is accepted when
/// g^((m-1)/q) != 1 for each such factor q.
constexpr int primitive_root_constexpr(int m) {
  if (m < 2) return -1;
  if (m == 2) return 1;
  if (m == 167772161) return 3;
  if (m == 469762049) return 3;
  if (m == 754974721) return 11;
  if (m == 880803841) return 26;
  if (m == 998244353) return 3;

  // A 31-bit value has at most 9 distinct prime factors.
  std::array<int, 20> divs{};
  int cnt = 0;
  int x = m-1;  // >= 2 here
  if (x % 2 == 0) {
    divs[cnt++] = 2;
    while (x % 2 == 0)
      x /= 2;
  }
  // `d <= x / d` instead of `d*d <= x`: d*d can overflow int for x near INT_MAX.
  for (int d = 3; d <= x / d; d += 2) {
    if (x % d == 0) {
      divs[cnt++] = d;
      while (x % d == 0)
        x /= d;
    }
  }
  if (x > 1) divs[cnt++] = x;
  for (int g = 2; g < m; g++) {
    bool ok = true;
    for (int i = 0; i < cnt; i++) {
      if (mod_pow_constexpr(g, (m-1) / divs[i], m) == 1) {
        ok = false;
        break;
      }
    }
    if (ok) return g;
  }
  return -1;
}

/// Compile-time primitive root of m (inline: one definition across translation units).
template<int m>
inline constexpr int primitive_root = primitive_root_constexpr(m);

}
#line 5 "test/yosupo/vertex_set_path_composite.test.cpp"
#include <bits/stdc++.h>
using namespace std;

using mint = Modular<998244353>;

struct Composite {
    mint a, b;
    Composite(mint a=1, mint b=0):a(a),b(b) {}
    Composite(pair<int,int> p):a(p.first),b(p.second) {}
    Composite operator*(const Composite& o) const { 
        return Composite(a*o.a, b*o.a + o.b);
    }
    mint eval(mint x) const {
        return a*x+b;
    }
};

int main() {
    int n,q; cin>>n>>q;
    vector<pair<int,int>> C(n);
    for (int i = 0; i < n; i++) {
        int a,b; cin>>a>>b;
        C[i] = {a,b};
    }
    Hld T(n);
    for (int i = 0; i < n-1; i++) {
        int u,v; cin>>u>>v;
        T.add_edge(u,v);
    } 
    T.build();
    decltype(C) D(n);
    for (int i = 0; i < n; i++) 
        D[T.in[i]] = C[i];
    SegmentHld<Composite> path_sum(T, D.begin(), D.end());
    for (int i = 0; i < q; i++) {
        int t; cin>>t;
        if (t == 0) {
            int p,c,d; cin>>p>>c>>d;
            T.set(p, [&](auto i, auto v) {path_sum.set(i, v);}, Composite(c, d));
        } else {
            int u,v,x; cin>>u>>v>>x;
            auto lq = [&](int l, int r) { 
                return path_sum.query(l, r); 
            };
            auto rq = [&](int l, int r) { 
                return path_sum.reverse_query(l, r); 
            };
            auto ret = T.query(u,v,lq,rq).eval(x);
            cout << ret << endl;
        }
    }
}

Test cases

Env Name Status Elapsed Memory
g++ almost_line_00 :heavy_check_mark: AC 908 ms 60 MB
g++ almost_line_01 :heavy_check_mark: AC 787 ms 66 MB
g++ example_00 :heavy_check_mark: AC 6 ms 3 MB
g++ example_01 :heavy_check_mark: AC 5 ms 3 MB
g++ line_00 :heavy_check_mark: AC 705 ms 78 MB
g++ line_01 :heavy_check_mark: AC 670 ms 84 MB
g++ long-path-decomposition_killer_00 :heavy_check_mark: AC 649 ms 40 MB
g++ max_random_00 :heavy_check_mark: AC 798 ms 40 MB
g++ max_random_01 :heavy_check_mark: AC 817 ms 40 MB
g++ max_random_02 :heavy_check_mark: AC 781 ms 40 MB
g++ random_00 :heavy_check_mark: AC 538 ms 27 MB
g++ random_01 :heavy_check_mark: AC 651 ms 31 MB
g++ random_02 :heavy_check_mark: AC 407 ms 13 MB
g++ small_00 :heavy_check_mark: AC 9 ms 4 MB
g++ small_01 :heavy_check_mark: AC 8 ms 3 MB
g++ small_02 :heavy_check_mark: AC 8 ms 3 MB
g++ small_03 :heavy_check_mark: AC 9 ms 4 MB
g++ small_04 :heavy_check_mark: AC 7 ms 4 MB
clang++ almost_line_00 :heavy_check_mark: AC 676 ms 43 MB
clang++ almost_line_01 :heavy_check_mark: AC 686 ms 44 MB
clang++ example_00 :heavy_check_mark: AC 6 ms 3 MB
clang++ example_01 :heavy_check_mark: AC 5 ms 3 MB
clang++ line_00 :heavy_check_mark: AC 723 ms 47 MB
clang++ line_01 :heavy_check_mark: AC 635 ms 48 MB
clang++ long-path-decomposition_killer_00 :heavy_check_mark: AC 642 ms 40 MB
clang++ max_random_00 :heavy_check_mark: AC 884 ms 40 MB
clang++ max_random_01 :heavy_check_mark: AC 981 ms 40 MB
clang++ max_random_02 :heavy_check_mark: AC 1058 ms 40 MB
clang++ random_00 :heavy_check_mark: AC 558 ms 27 MB
clang++ random_01 :heavy_check_mark: AC 693 ms 31 MB
clang++ random_02 :heavy_check_mark: AC 372 ms 13 MB
clang++ small_00 :heavy_check_mark: AC 9 ms 4 MB
clang++ small_01 :heavy_check_mark: AC 8 ms 3 MB
clang++ small_02 :heavy_check_mark: AC 8 ms 3 MB
clang++ small_03 :heavy_check_mark: AC 9 ms 4 MB
clang++ small_04 :heavy_check_mark: AC 7 ms 4 MB
Back to top page