matsutaku-library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub MatsuTaku/matsutaku-library

✔ test/yosupo/range_affine_range_sum-splay_tree.test.cpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/range_affine_range_sum"
#include "include/mtl/splay_tree_list.hpp"
#include "include/mtl/modular.hpp"
#include <bits/stdc++.h>

// Test-file convenience aliases.
using namespace std;
using ll = long long;
// Modulus used by the range_affine_range_sum problem.
constexpr ll MOD = 998244353;
// Modular integer type for all values and affine coefficients.
using mint = Modular<MOD>;

struct Sum {
  mint a=0;
  int sz=0;
  Sum operator*(const Sum& r) const {
    return {a+r.a, sz+r.sz};
  }
};

struct Affine {
  mint b=1, c=0;
  Affine& operator*=(const Affine& r) {
    b *= r.b;
    c = c*r.b + r.c;
    return *this;
  }
  Sum act(const Sum& a) const {
    return {b*a.a + c*a.sz, a.sz};
  }
};

/// Reads N values and Q queries:
///   0 l r b c : a[i] = b*a[i] + c for i in [l, r)
///   1 l r     : print sum of a[i] for i in [l, r)
int main() {
  cin.tie(nullptr); ios::sync_with_stdio(false);

  int N,Q; cin>>N>>Q;
  vector<Sum> A(N);
  for (auto& a : A) {
    cin>>a.a;
    a.sz = 1;
  }
  SplayTreeList<Sum, Affine> rsq(A.begin(), A.end());

  for (int q = 0; q < Q; q++) {
    int t; cin>>t;
    if (t == 0) {
      int l,r,b,c; cin>>l>>r>>b>>c;
      rsq.update(l,r, {b,c});
    } else if (t == 1) {
      int l,r; cin>>l>>r;
      // '\n' instead of endl: flushing after every query is needless I/O cost.
      cout << rsq.prod(l,r).a << '\n';
    }
  }

  return 0;
}
#line 1 "test/yosupo/range_affine_range_sum-splay_tree.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/range_affine_range_sum"
#line 2 "include/mtl/splay_tree.hpp"
#include <cassert>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>

/// CRTP base with the structural links of a splay-tree node.
/// `l`/`r` own the children through shared_ptr; `p` is a weak back-link to the
/// parent, so parent and child do not form a shared_ptr ownership cycle.
template<class NodeType>
struct SplayTreeNodeBase {
    using node_type = NodeType;
    using node_shared = std::shared_ptr<node_type>;
    using node_weak = std::weak_ptr<node_type>;
    node_shared l,r;
    node_weak p;
    // Pending-reversal flag. Defaulted to false so the node is well-formed even
    // when it is default-initialized (otherwise the value would be indeterminate).
    bool rev = false;
    /// True if there is no live parent, or if the node `p` points to does not
    /// have this node as its left or right child.
    bool is_root() const {
        const auto parent = p.lock();  // lock once instead of twice
        return !parent || (parent->l.get() != this && parent->r.get() != this);
    }
};
/// Re-exports the type aliases a node type `T` declares, so that tree code
/// is written against the traits rather than against the node directly.
template<class T>
struct SplayTreeNodeTraits {
    using node_weak = typename T::node_weak;
    using node_shared = typename T::node_shared;
    using node_type = typename T::node_type;
};

/// Splay-tree machinery shared by concrete trees.
/// Nodes must provide `l`/`r` (owning child links), `p` (weak parent link) and
/// `is_root()`. Derived trees override the hooks below; in `splay`, `propagate`
/// is called on nodes before they are rotated and `aggregate` on them afterwards,
/// lowest node first. The base-class hooks do nothing.
template<class Node>
struct SplayTreeBase {
    using node_traits = SplayTreeNodeTraits<Node>;
    using node_type = typename node_traits::node_type;
    using node_shared = typename node_traits::node_shared;
    using node_weak = typename node_traits::node_weak;
    SplayTreeBase() = default;
    // The class has virtual hooks; destroying a derived tree through a
    // SplayTreeBase pointer must reach the derived destructor.
    virtual ~SplayTreeBase() = default;

    /// Rotate `u`, the right child of p, above p (p becomes u's left child).
    /// The grandparent's child link is redirected only if p really was one of
    /// its children; otherwise just u->p takes over p's parent link.
    void rotate_left(const node_shared& u) const {
        auto p = u->p.lock(), q = p->p.lock();
        p->r = u->l;
        if (p->r)
            p->r->p = p;
        u->l = p;
        p->p = u;
        u->p = q;
        if (q) {
            if (q->l == p)
                q->l = u;
            else if (q->r == p)
                q->r = u;
        }
    }
    /// Mirror of rotate_left: `u` is the left child of p and p becomes u's right child.
    void rotate_right(const node_shared& u) const {
        auto p = u->p.lock(), q = p->p.lock();
        p->l = u->r;
        if (p->l)
            p->l->p = p;
        u->r = p;
        p->p = u;
        u->p = q;
        if (q) {
            if (q->l == p)
                q->l = u;
            else if (q->r == p)
                q->r = u;
        }
    }
    /// Hook: called on a node whose subtree order was logically reversed.
    virtual void reverse_prod(const node_shared& u) const {}
    /// Hook: called on a node before it is rotated.
    virtual void propagate(const node_shared& u) const {}
    /// Hook: called on a node after its children changed.
    virtual void aggregate(const node_shared& u) const {}
    /// Rotate `u` upward (zig / zig-zig / zig-zag steps) until `u->is_root()`.
    /// Throws "invalid tree" if a parent does not list its child.
    void splay(const node_shared& u) const {
        if (u->is_root()) {
            this->propagate(u);
            this->aggregate(u);
            return;
        }
        do {
            assert(u);
            auto p = u->p.lock();
            assert(p);
            if (p->is_root()) {
                // zig: single rotation over the root.
                this->propagate(p);
                this->propagate(u);
                if (p->l == u)
                    rotate_right(u);
                else if (p->r == u)
                    rotate_left(u);
                else throw "invalid tree";
                this->aggregate(p);
                this->aggregate(u);
            } else {
                auto q = p->p.lock();
                // Push lazy state top-down before any child links move.
                this->propagate(q);
                this->propagate(p);
                this->propagate(u);
                if (q->l == p) {
                    if (p->l == u) { // zig-zig: rotate parent first, then u
                        rotate_right(p);
                        rotate_right(u);
                        this->aggregate(q);
                        this->aggregate(p);
                    } else if (p->r == u) { // zig-zag: rotate u twice
                        rotate_left(u);
                        rotate_right(u);
                        this->aggregate(p);
                        this->aggregate(q);
                    } else throw "invalid tree";
                } else if (q->r == p) {
                    if (p->r == u) { // zig-zig
                        rotate_left(p);
                        rotate_left(u);
                        this->aggregate(q);
                        this->aggregate(p);
                    } else if (p->l == u) { // zig-zag
                        rotate_right(u);
                        rotate_left(u);
                        this->aggregate(p);
                        this->aggregate(q);
                    } else throw "invalid tree";
                }
                this->aggregate(u);
            }
        } while (!u->is_root());
    }

};
#line 2 "include/mtl/monoid.hpp"
#include <utility>
#if __cpp_concepts >= 202002L
#include <concepts>
#endif

template<class T, T (*op)(T, T), T (*e)()>
struct Monoid {
  T x;
  Monoid() : x(e()) {}
  template<class... Args>
  Monoid(Args&&... args) : x(std::forward<Args>(args)...) {}
  Monoid operator*(const Monoid& rhs) const {
    return Monoid(op(x, rhs.x));
  }
  const T& val() const {
    return x;
  }
};

/// Monoid with no data: every product yields the same empty value.
struct VoidMonoid {
  VoidMonoid() {}
  VoidMonoid operator*(const VoidMonoid& /*unused*/) const {
    return {};
  }
};

#if __cpp_concepts >= 202002L
// Syntactic monoid check: `m * m` must be valid and yield exactly T.
// Only the expression and its type are checked, not the algebraic laws.
template<class T>
concept IsMonoid = requires (T m) {
  { m * m } -> std::same_as<T>;
};
#endif

/// Monoid whose operation is also exposed as `+` (intended for commutative `op`).
template<class T, T (*op)(T, T), T (*e)()>
struct CommutativeMonoid : public Monoid<T, op, e> {
    using Base = Monoid<T, op, e>;
    CommutativeMonoid(T x=e()) : Base(x) {}
    /// op(x, rhs.x) wrapped as a CommutativeMonoid.
    /// Built from the raw T: `*this * rhs` yields a Base (Monoid), and there is
    /// no constructor from Base, so the previous form did not compile.
    CommutativeMonoid operator+(const CommutativeMonoid& rhs) const {
        return CommutativeMonoid(op(this->x, rhs.x));
    }
};

#if __cpp_concepts >= 202002L
// Syntactic check: `m + m` must be valid and yield exactly T.
// Commutativity itself is not (and cannot be) verified here.
template<class T>
concept IsCommutativeMonoid = requires (T m) {
  { m + m } -> std::same_as<T>;
};
#endif

template<class S, class F, S (*mapping)(F, S), F (*composition)(F, F), F (*id)()>
struct OperatorMonoid {
    F f;
    OperatorMonoid() : f(id()) {}
    template<class... Args>
    OperatorMonoid(Args&&... args) : f(std::forward<Args>(args)...) {}
    OperatorMonoid& operator*=(const OperatorMonoid& rhs) {
        f = composition(rhs.f, f);
        return *this;
    }
    S act(const S& s) const {
        return mapping(f, s);
    }
};

/// Operator monoid with only the identity: composing is a no-op and `act`
/// hands back its argument unchanged.
struct VoidOperatorMonoid {
    VoidOperatorMonoid() {}
    VoidOperatorMonoid& operator*=(const VoidOperatorMonoid& /*unused*/) {
        return *this;
    }
    template<class U>
    U act(const U& value) const {
        return value;
    }
};

#if __cpp_concepts >= 202002L
// Syntactic check for an operator monoid F acting on S: in-place composition
// `f *= f` must return F&, and `f.act(s)` must return exactly S.
template<class F, class S>
concept IsOperatorMonoid = requires (F f, S s) {
    { f *= f } -> std::same_as<F&>;
    { f.act(s) } -> std::same_as<S>;
};
#endif
#line 4 "include/mtl/splay_tree_list.hpp"

/// Node of SplayTreeList: one element `m` plus cached subtree data.
///   sum   : number of nodes in this subtree
///   prod  : product of the subtree's elements, left to right
///   rprod : product right to left (a reversal only needs to swap prod/rprod)
///   f     : operator still to be pushed down to the children
template<class M, class O>
struct SplayTreeListNode : SplayTreeNodeBase<SplayTreeListNode<M,O>> {
    size_t sum;
    M m, prod, rprod;
    O f;
    using Base = SplayTreeNodeBase<SplayTreeListNode<M,O>>;
    // prod/rprod copy from a const view of m so that M's copy constructor is
    // chosen even if M has a greedy forwarding constructor that would otherwise
    // be the better match for a non-const lvalue.
    template<class... Args>
    SplayTreeListNode(Args&&... args)
        : Base(), sum(1), m(std::forward<Args>(args)...),
          prod(std::as_const(m)), rprod(std::as_const(m)), f() {}
};
/// Implicit-key sequence stored in a splay tree. Positions are 0-based and
/// ranges are half-open [l, r).
/// Supports positional insert/erase/set, range product of M (`prod`), range
/// application of the operator monoid O (`update`) and range reversal.
/// Lazy state per node: `f` is pushed to the children by `propagate`, and
/// `rev` swaps the children (and their prod/rprod) when propagated.
template<class M, class O=VoidOperatorMonoid>
#if __cpp_concepts >= 202002L
requires IsMonoid<M> && IsOperatorMonoid<O, M>
#endif
struct SplayTreeList : SplayTreeBase<SplayTreeListNode<M, O>> {
    using node_type = SplayTreeListNode<M, O>;
    using base = SplayTreeBase<node_type>;
    using monoid_type = M;
    using operator_monoid_type = O;
    using node_shared = typename SplayTreeNodeTraits<node_type>::node_shared;
    node_shared root;

    SplayTreeList() : base(), root(nullptr) {}

    /// Build a balanced subtree over [first, last) by recursive midpoint
    /// splitting. Walks the range more than once, so it needs a multi-pass
    /// (forward or stronger) iterator.
    template<class InputIt>
    node_shared _dfs_init(InputIt first, InputIt last) {
        if (first == last) return nullptr;
        auto n = std::distance(first, last);
        auto mid = std::next(first, n / 2);
        auto u = std::make_shared<node_type>(*mid);
        u->l = _dfs_init(first, mid);
        if (u->l) u->l->p = u;
        u->r = _dfs_init(std::next(mid), last);
        if (u->r) u->r->p = u;
        this->aggregate(u);
        return u;
    }
    /// Construct from [first, last). Forward-or-stronger iterators produce a
    /// balanced tree; single-pass input iterators are read exactly once and
    /// produce a left-leaning chain.
    template<class InputIt>
    SplayTreeList(InputIt first, InputIt last) : SplayTreeList() {
        if (first == last) return;
        using iterator_category = typename std::iterator_traits<InputIt>::iterator_category;
        // is_base_of<Base, Derived>: true for forward/bidirectional/random-access
        // (and contiguous) categories, false for pure input iterators, which
        // _dfs_init cannot traverse repeatedly.
        if constexpr (std::is_base_of<std::forward_iterator_tag, iterator_category>::value) {
            root = _dfs_init(first, last);
        } else {
            auto it = first;
            root = std::make_shared<node_type>(*it);
            ++it;
            for (; it != last; ++it) {
                auto u = std::make_shared<node_type>(*it);
                u->l = root;
                root->p = u;
                root = u;
                this->aggregate(u);
            }
        }
    }
    /// Number of elements in the sequence (0 when empty).
    size_t size() const { return root ? root->sum : 0; }

    void reverse_prod(const node_shared& u) const override {
        std::swap(u->prod, u->rprod);
    }
    /// Push u's pending operator and reversal one level down, then clear them.
    void propagate(const node_shared& u) const override {
        if (u->l) {
            u->l->m = u->f.act(u->l->m);
            u->l->prod = u->f.act(u->l->prod);
            u->l->rprod = u->f.act(u->l->rprod);
            u->l->f *= u->f;
        }
        if (u->r) {
            u->r->m = u->f.act(u->r->m);
            u->r->prod = u->f.act(u->r->prod);
            u->r->rprod = u->f.act(u->r->rprod);
            u->r->f *= u->f;
        }
        if (u->rev) {
            std::swap(u->l, u->r);
            if (u->l) {
                u->l->rev ^= true;
                reverse_prod(u->l);
            }
            if (u->r) {
                u->r->rev ^= true;
                reverse_prod(u->r);
            }
            u->rev = false;
        }
        u->f = operator_monoid_type();
    }
    /// Recompute sum, prod and rprod of u from m and its children.
    void aggregate(const node_shared& u) const override {
        u->sum = 1;
        u->prod = u->m;
        u->rprod = u->m;
        if (u->l) {
            u->sum += u->l->sum;
            u->prod = u->l->prod * u->prod;
            u->rprod = u->rprod * u->l->rprod;
        }
        if (u->r) {
            u->sum += u->r->sum;
            u->prod = u->prod * u->r->prod;
            u->rprod = u->r->rprod * u->rprod;
        }
    }
    /// Splay the k-th element (0-based, k < size()) to the root and return it.
    node_shared kth_element(size_t k) {
        assert(k < size());
        auto u = root;
        while (true) {
            assert(u);
            auto lp = u->l ? u->l->sum : 0;
            auto rp = u->r ? u->r->sum : 0;
            assert(u->sum == lp+rp+1);
            propagate(u);  // may swap children, so read u->l only afterwards
            if (!u->l) {
                if (k == 0)
                    break;
                k--;
                u = u->r;
            } else if (u->l->sum == k) {
                break;
            } else if (u->l->sum > k) {
                u = u->l;
            } else {
                k -= u->l->sum + 1;
                u = u->r;
            }
        }
        assert(u);
        base::splay(u);
        root = u;
        return u;
    }
    /// Insert a new element, constructed from `args`, at position i (i <= size()).
    template<class... Args>
    void insert(size_t i, Args&&... args) {
        assert(i <= size());
        auto u = std::make_shared<node_type>(std::forward<Args>(args)...);
        if (i == 0) {
            u->r = root;
            if (root) root->p = u;
            root = u;
            aggregate(u);
            return;
        }
        if (i == root->sum) {
            u->l = root;
            if (root) root->p = u;
            root = u;
            aggregate(u);
            return;
        }
        auto p = kth_element(i);
        u->l = p->l;
        u->r = p;
        if (u->l)
            u->l->p = u;
        u->r->p = u;
        p->l = nullptr;
        root = u;
        aggregate(p);
        aggregate(u);
    }
    /// Remove the element at position i (i < size()).
    void erase(size_t i) {
        assert(i < size());
        auto p = kth_element(i);
        if (i == 0) {
            root = p->r;
            if (root)
                root->p.reset();
            return;
        }
        if (i == root->sum-1) {
            root = p->l;
            if (root)
                root->p.reset();
            return;
        }
        auto r = p->r;
        auto l = p->l;
        root = r;
        root->p.reset();
        r = kth_element(0);  // minimum of the right part; its left child is empty
        r->l = l;
        l->p = r;
        aggregate(r);
        // The tree no longer references p, so it is freed when the local goes out of scope.
    }
    /// Restructure the tree so that one subtree holds exactly the elements of
    /// [l, r), and return that subtree's root (nullptr when l == r).
    node_shared between(size_t l, size_t r) {
        assert(l <= r && r <= size());
        if (l == 0) {
            if (r == root->sum)
                return root;
            else
                return kth_element(r)->l;
        }
        if (r == root->sum) {
            return kth_element(l-1)->r;
        }
        auto rp = kth_element(r);
        root = rp->l;
        root->p.reset();
        auto lp = kth_element(l-1);
        rp->l = lp;
        lp->p = rp;
        root = rp;
        aggregate(rp);
        return lp->r;
    }
    /// Product of the elements in [l, r); the identity M() for an empty range.
    monoid_type prod(size_t l, size_t r) {
        if (l == r) return monoid_type();  // between() yields nullptr here
        return between(l, r)->prod;
    }
    /// Reverse the order of the elements in [l, r).
    void reverse(size_t l, size_t r) {
        if (l == r) return;
        auto u = between(l, r);
        u->rev ^= true;
        reverse_prod(u);
        base::splay(u);
        root = u;
    }
    /// Replace the element at position i with M(args...).
    template<class... Args>
    void set(size_t i, Args&&... args) {
        auto u = kth_element(i);
        u->m = monoid_type(std::forward<Args>(args)...);
        this->aggregate(u);
    }
    /// Apply f to the element at position i.
    void update(size_t i, const operator_monoid_type& f) {
        auto u = kth_element(i);
        u->m = f.act(u->m);
        this->aggregate(u);
    }
    /// Apply f to every element in [l, r) (requires l < r); children receive it lazily.
    void update(size_t l, size_t r, const operator_monoid_type& f) {
        assert(l < r);
        auto u = between(l, r);
        u->m = f.act(u->m);
        u->prod = f.act(u->prod);
        u->rprod = f.act(u->rprod);
        u->f *= f;
        base::splay(u);
        root = u;
    }
};
#line 2 "include/mtl/bit_manip.hpp"
#include <cstdint>
#line 4 "include/mtl/bit_manip.hpp"
#if __cplusplus >= 202002L
#ifndef MTL_CPP20
#define MTL_CPP20
#endif
#include <bit>
#endif

namespace bm {

/// Count 1s for each 8 bits
inline constexpr uint64_t popcnt_e8(uint64_t x) {
  x = (x & 0x5555555555555555) + ((x>>1) & 0x5555555555555555);
  x = (x & 0x3333333333333333) + ((x>>2) & 0x3333333333333333);
  x = (x & 0x0F0F0F0F0F0F0F0F) + ((x>>4) & 0x0F0F0F0F0F0F0F0F);
  return x;
}
/// Count 1s
inline constexpr unsigned popcnt(uint64_t x) {
#ifdef MTL_CPP20
  return std::popcount(x);
#else
  return (popcnt_e8(x) * 0x0101010101010101) >> 56;
#endif
}
/// Alias to mtl::popcnt(x)
constexpr unsigned popcount(uint64_t x) {
  return popcnt(x);
}
/// Count trailing 0s. s.t. *11011000 -> 3 (64 for x == 0)
inline constexpr unsigned ctz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_zero(x);
#else
  return popcnt((x & (-x)) - 1);
#endif
}
/// Alias to mtl::ctz(x)
constexpr unsigned countr_zero(uint64_t x) {
  return ctz(x);
}
/// Count trailing 1s. s.t. *11011011 -> 2
inline constexpr unsigned cto(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_one(x);
#else
  return ctz(~x);
#endif
}
/// Alias to mtl::cto(x)
constexpr unsigned countr_one(uint64_t x) {
  return cto(x);
}
/// Count trailing 0s of an 8-bit value (8 for x == 0).
inline constexpr unsigned ctz8(uint8_t x) {
  return x == 0 ? 8 : popcnt_e8((x & (-x)) - 1);
}
/// [00..0](8bit) -> 0, [**..*](not only 0) -> 1
/// Bit i of the result is set iff byte i of x is non-zero.
inline constexpr uint8_t summary(uint64_t x) {
  constexpr uint64_t hmask = 0x8080808080808080ull;
  constexpr uint64_t lmask = 0x7F7F7F7F7F7F7F7Full;
  auto a = x & hmask;
  auto b = x & lmask;
  b = hmask - b;
  b = ~b;
  auto c = (a | b) & hmask;
  c *= 0x0002040810204081ull;
  return uint8_t(c >> 56);
}
/// Extract target area of bits
inline constexpr uint64_t bextr(uint64_t x, unsigned start, unsigned len) {
  uint64_t mask = len < 64 ? (1ull<<len)-1 : 0xFFFFFFFFFFFFFFFFull;
  return (x >> start) & mask;
}
/// Bit width of an 8-bit value: 00101101 -> 00111111 -count_1s-> 6; 0 -> 0.
inline constexpr unsigned log2p1(uint8_t x) {
  // The borrow-based byte comparison below reports 1 for x == 0, because the
  // borrow chain reaches the top byte; handle zero explicitly.
  if (x == 0)
    return 0;
  if (x & 0x80)
    return 8;
  uint64_t p = uint64_t(x) * 0x0101010101010101ull;
  p -= 0x8040201008040201ull;
  p = ~p & 0x8080808080808080ull;
  p = (p >> 7) * 0x0101010101010101ull;
  p >>= 56;
  return p;
}
/// 00101100 -mask_mssb-> 00100000 -to_index-> 5
inline constexpr unsigned mssb8(uint8_t x) {
  assert(x != 0);
  return log2p1(x) - 1;
}
/// 00101100 -mask_lssb-> 00000100 -to_index-> 2
inline constexpr unsigned lssb8(uint8_t x) {
  assert(x != 0);
  return popcnt_e8((x & -x) - 1);
}
/// Count leading 0s. 00001011... -> 4
inline constexpr unsigned clz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_zero(x);
#else
  if (x == 0)
    return 64;
  auto i = mssb8(summary(x));
  auto j = mssb8(bextr(x, 8 * i, 8));
  return 63 - (8 * i + j);
#endif
}
/// Alias to mtl::clz(x)
constexpr unsigned countl_zero(uint64_t x) {
  return clz(x);
}
/// Count leading 1s. 11110100... -> 4
inline constexpr unsigned clo(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_one(x);
#else
  return clz(~x);
#endif
}
/// Alias to mtl::clo(x)
constexpr unsigned countl_one(uint64_t x) {
  return clo(x);
}

/// Count leading 0s of an 8-bit value (8 for x == 0).
inline constexpr unsigned clz8(uint8_t x) {
  return x == 0 ? 8 : 7 - mssb8(x);
}
/// Reverse the order of all 64 bits.
inline constexpr uint64_t bit_reverse(uint64_t x) {
  x = ((x & 0x00000000FFFFFFFF) << 32) | ((x & 0xFFFFFFFF00000000) >> 32);
  x = ((x & 0x0000FFFF0000FFFF) << 16) | ((x & 0xFFFF0000FFFF0000) >> 16);
  x = ((x & 0x00FF00FF00FF00FF) << 8) | ((x & 0xFF00FF00FF00FF00) >> 8);
  x = ((x & 0x0F0F0F0F0F0F0F0F) << 4) | ((x & 0xF0F0F0F0F0F0F0F0) >> 4);
  x = ((x & 0x3333333333333333) << 2) | ((x & 0xCCCCCCCCCCCCCCCC) >> 2);
  x = ((x & 0x5555555555555555) << 1) | ((x & 0xAAAAAAAAAAAAAAAA) >> 1);
  return x;
}

/// Check if x is power of 2. 00100000 -> true, 00100001 -> false
constexpr bool has_single_bit(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::has_single_bit(x);
#else
  return x != 0 && (x & (x - 1)) == 0;
#endif
}

/// Bit width needs to represent x. 00110110 -> 6
constexpr int bit_width(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::bit_width(x);
#else
  return 64 - clz(x);
#endif
}

/// Ceil power of 2. 00110110 -> 01000000
constexpr uint64_t bit_ceil(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_ceil(x);
#else
  if (x == 0) return 1;
  return 1ull << bit_width(x - 1);
#endif
}

/// Floor power of 2. 00110110 -> 00100000
constexpr uint64_t bit_floor(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_floor(x);
#else
  if (x == 0) return 0;
  return 1ull << (bit_width(x) - 1);
#endif
}

} // namespace bm
#line 3 "include/mtl/modular.hpp"
#include <iostream>
#line 5 "include/mtl/modular.hpp"

/// Integer modulo the compile-time constant MOD, kept in [0, MOD).
/// `inv`, `/` and `sqrt` use Fermat-style exponentiation and so assume MOD is prime.
template <int MOD>
class Modular {
 private:
  unsigned int val_;

 public:
  static constexpr unsigned int mod() { return MOD; }
  /// Reduce any integer, including negative ones, into [0, mod()).
  template<class T>
  static constexpr unsigned int safe_mod(T v) {
    auto x = (long long)(v%(long long)mod());
    if (x < 0) x += mod();
    return (unsigned int) x;
  }

  constexpr Modular() : val_(0) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && std::is_unsigned<T>::value
      > * = nullptr>
  constexpr Modular(T v) : val_(v%mod()) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && !std::is_unsigned<T>::value
      > * = nullptr>
  constexpr Modular(T v) : val_(safe_mod(v)) {}

  constexpr unsigned int val() const { return val_; }
  constexpr Modular& operator+=(Modular x) {
    val_ += x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr Modular operator-() const { return {mod() - val_}; }
  constexpr Modular& operator-=(Modular x) {
    val_ += mod() - x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr Modular& operator*=(Modular x) {
    auto v = (long long) val_ * x.val();
    if (v >= mod()) v %= mod();
    val_ = v;
    return *this;
  }
  /// *this raised to p (p >= 0) by binary exponentiation.
  constexpr Modular pow(long long p) const {
    assert(p >= 0);
    Modular t = 1;
    Modular u = *this;
    while (p) {
      if (p & 1)
        t *= u;
      u *= u;
      p >>= 1;
    }
    return t;
  }
  friend constexpr Modular pow(Modular x, long long p) {
    return x.pow(p);
  }
  constexpr Modular inv() const { return pow(mod()-2); }
  constexpr Modular& operator/=(Modular x) { return *this *= x.inv(); }
  constexpr Modular operator+(Modular x) const { return Modular(*this) += x; }
  constexpr Modular operator-(Modular x) const { return Modular(*this) -= x; }
  constexpr Modular operator*(Modular x) const { return Modular(*this) *= x; }
  constexpr Modular operator/(Modular x) const { return Modular(*this) /= x; }
  constexpr Modular& operator++() { return *this += 1; }
  constexpr Modular operator++(int) { Modular c = *this; ++(*this); return c; }
  constexpr Modular& operator--() { return *this -= 1; }
  constexpr Modular operator--(int) { Modular c = *this; --(*this); return c; }

  constexpr bool operator==(Modular x) const { return val() == x.val(); }
  constexpr bool operator!=(Modular x) const { return val() != x.val(); }

  /// Euler's criterion; 0 counts as a square (0 = 0 * 0), matching DynamicModular.
  constexpr bool is_square() const {
    return val() == 0 || pow((mod()-1)/2) == 1;
  }
  /**
   * Return x s.t. x * x = a mod p
   * reference: https://zenn.dev/peria/articles/c6afc72b6b003c
   * Throws std::runtime_error("not square") when no root exists.
  */
  constexpr Modular sqrt() const {
    if (!is_square())
      throw std::runtime_error("not square");
    if (val() < 2)
      return val();  // 0 and 1 are their own roots
    auto mod_eight = mod() % 8;
    if (mod_eight == 3 || mod_eight == 7) {
      return pow((mod()+1)/4);
    } else if (mod_eight == 5) {
      auto x = pow((mod()+3)/8);
      if (x * x != *this)
        x *= Modular(2).pow((mod()-1)/4);
      return x;
    } else {
      Modular d = 2;
      while (d.is_square())
        d += 1;
      auto t = mod()-1;
      int s = bm::ctz(t);
      t >>= s;
      auto a = pow(t);
      auto D = d.pow(t);
      int m = 0;
      Modular dt = 1;
      Modular du = D;
      for (int i = 0; i < s; i++) {
        if ((a*dt).pow(1u<<(s-1-i)) == -1) {
          m |= 1u << i;
          dt *= du;
        }
        du *= du;
      }
      return pow((t+1)/2) * D.pow(m/2);
    }
  }

  friend std::ostream& operator<<(std::ostream& os, const Modular& x) {
    return os << x.val();
  }
  /// Read an integer and reduce it into [0, mod()); negative input and values
  /// >= mod() are normalised instead of being stored raw. x is left unchanged
  /// if extraction fails.
  friend std::istream& operator>>(std::istream& is, Modular& x) {
    long long v;
    if (is >> v) x = Modular(v);
    return is;
  }

};

// Shorthands for the two most common prime moduli.
using Modular998244353 = Modular<998244353>;
using Modular1000000007 = Modular<1000000007>;

/// Integer modulo a runtime modulus shared by every DynamicModular<Id> with the
/// same Id. Call set_mod() before creating values: mod() is 0 until then.
template<int Id=0>
class DynamicModular {
 private:
  static unsigned int mod_;
  unsigned int val_;

 public:
  static unsigned int mod() { return mod_; }
  static void set_mod(unsigned int m) { mod_ = m; }
  /// Reduce any integer, including negative ones, into [0, mod()).
  template<class T>
  static constexpr unsigned int safe_mod(T v) {
    auto x = (long long)(v%(long long)mod());
    if (x < 0) x += mod();
    return (unsigned int) x;
  }

  constexpr DynamicModular() : val_(0) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && std::is_unsigned<T>::value
      > * = nullptr>
  constexpr DynamicModular(T v) : val_(v%mod()) {}
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && !std::is_unsigned<T>::value
      > * = nullptr>
  constexpr DynamicModular(T v) : val_(safe_mod(v)) {}

  constexpr unsigned int val() const { return val_; }
  constexpr DynamicModular& operator+=(DynamicModular x) {
    val_ += x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr DynamicModular operator-() const { return {mod() - val_}; }
  constexpr DynamicModular& operator-=(DynamicModular x) {
    val_ += mod() - x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr DynamicModular& operator*=(DynamicModular x) {
    auto v = (long long) val_ * x.val();
    if (v >= mod()) v %= mod();
    val_ = v;
    return *this;
  }
  /// *this raised to p (p >= 0) by binary exponentiation.
  constexpr DynamicModular pow(long long p) const {
    assert(p >= 0);
    DynamicModular t = 1;
    DynamicModular u = *this;
    while (p) {
      if (p & 1)
        t *= u;
      u *= u;
      p >>= 1;
    }
    return t;
  }
  friend constexpr DynamicModular pow(DynamicModular x, long long p) {
    return x.pow(p);
  }
  // TODO: implement when mod is not prime
  constexpr DynamicModular inv() const { return pow(mod()-2); }
  constexpr DynamicModular& operator/=(DynamicModular x) { return *this *= x.inv(); }
  constexpr DynamicModular operator+(DynamicModular x) const { return DynamicModular(*this) += x; }
  constexpr DynamicModular operator-(DynamicModular x) const { return DynamicModular(*this) -= x; }
  constexpr DynamicModular operator*(DynamicModular x) const { return DynamicModular(*this) *= x; }
  constexpr DynamicModular operator/(DynamicModular x) const { return DynamicModular(*this) /= x; }
  constexpr DynamicModular& operator++() { return *this += 1; }
  constexpr DynamicModular operator++(int) { DynamicModular c = *this; ++(*this); return c; }
  constexpr DynamicModular& operator--() { return *this -= 1; }
  constexpr DynamicModular operator--(int) { DynamicModular c = *this; --(*this); return c; }

  constexpr bool operator==(DynamicModular x) const { return val() == x.val(); }
  constexpr bool operator!=(DynamicModular x) const { return val() != x.val(); }

  /// Euler's criterion; 0 counts as a square.
  constexpr bool is_square() const {
    return val() == 0 or pow((mod()-1)/2) == 1;
  }
  /**
   * Return x s.t. x * x = a mod p
   * reference: https://zenn.dev/peria/articles/c6afc72b6b003c
   * Throws std::runtime_error("not square") when no root exists.
  */
  constexpr DynamicModular sqrt() const {
    // assert mod is prime
    if (!is_square())
      throw std::runtime_error("not square");
    if (val() < 2)
      return val();
    auto mod_eight = mod() % 8;
    if (mod_eight == 3 || mod_eight == 7) {
      return pow((mod()+1)/4);
    } else if (mod_eight == 5) {
      auto x = pow((mod()+3)/8);
      if (x * x != *this)
        x *= DynamicModular(2).pow((mod()-1)/4);
      return x;
    } else {
      DynamicModular d = 2;
      while (d.is_square())
        ++d;
      auto t = mod()-1;
      int s = bm::ctz(t);
      t >>= s;
      auto a = pow(t);
      auto D = d.pow(t);
      int m = 0;
      DynamicModular dt = 1;
      DynamicModular du = D;
      for (int i = 0; i < s; i++) {
        if ((a*dt).pow(1u<<(s-1-i)) == -1) {
          m |= 1u << i;
          dt *= du;
        }
        du *= du;
      }
      return pow((t+1)/2) * D.pow(m/2);
    }
  }

  friend std::ostream& operator<<(std::ostream& os, const DynamicModular& x) {
    return os << x.val();
  }
  /// Read an integer and reduce it into [0, mod()) rather than storing the raw
  /// (possibly out-of-range or negative) input. x is unchanged if extraction fails.
  friend std::istream& operator>>(std::istream& is, DynamicModular& x) {
    long long v;
    if (is >> v) x = DynamicModular(v);
    return is;
  }

};
template<int Id>
unsigned int DynamicModular<Id>::mod_;

#include <vector>

/// Lazily-grown table of modular inverses for a modular type whose `mod()` is
/// usable in a constant expression. The table is a static member, shared by
/// every ModularUtil<ModInt>.
template<class ModInt>
struct ModularUtil {
  static constexpr int mod = ModInt::mod();
  static struct inv_table {
    // tb[i] is the inverse of i for 1 <= i < tb.size(); tb[0] is a 0 placeholder.
    std::vector<ModInt> tb{0,1};
  } inv_;
  /// Extend the table so it covers 1..n. Uses
  /// inv(i) = -(mod / i) * inv(mod % i), valid when `mod` is prime.
  /// Static: it touches only the shared table (instance calls still work).
  static void set_inv(int n) {
    const int m = static_cast<int>(inv_.tb.size());
    if (m > n) return;
    inv_.tb.resize(n+1);
    for (int i = m; i <= n; i++)
      inv_.tb[i] = -inv_.tb[mod % i] * (mod / i);
  }
  /// Inverse of i. The returned reference points into the table and is
  /// invalidated by any later call that grows it.
  static ModInt& inv(int i) {
    set_inv(i);
    return inv_.tb[i];
  }
};
template<class ModInt>
typename ModularUtil<ModInt>::inv_table ModularUtil<ModInt>::inv_;

#include <array>

namespace math {

/// x^p mod m by binary exponentiation, usable in constant expressions.
/// Intermediate products are held in long long.
constexpr int mod_pow_constexpr(int x, int p, int m) {
  long long acc = 1;
  long long sq = x;
  for (; p; p >>= 1) {
    if (p & 1)
      acc = acc * sq % m;
    sq = sq * sq % m;
  }
  return static_cast<int>(acc);
}

/// A primitive root modulo a prime m. Well-known NTT moduli come from a table;
/// otherwise this returns the smallest g >= 2 with g^((m-1)/q) != 1 for every
/// prime q dividing m-1, or -1 if the search finds none.
constexpr int primitive_root_constexpr(int m) {
  switch (m) {
    case 2: return 1;
    case 167772161: return 3;
    case 469762049: return 3;
    case 754974721: return 11;
    case 880803841: return 26;
    case 998244353: return 3;
    default: break;
  }

  // Distinct prime factors of m-1.
  std::array<int, 20> factors{};
  int num_factors = 0;
  int rest = m-1;
  if (rest % 2 == 0) {
    factors[num_factors++] = 2;
    rest >>= bm::ctz(rest);
  }
  for (int d = 3; d*d <= rest; d += 2) {
    if (rest % d != 0)
      continue;
    factors[num_factors++] = d;
    do {
      rest /= d;
    } while (rest % d == 0);
  }
  if (rest > 1)
    factors[num_factors++] = rest;

  for (int g = 2; g < m; g++) {
    int i = 0;
    while (i < num_factors && mod_pow_constexpr(g, (m-1) / factors[i], m) != 1)
      ++i;
    if (i == num_factors)
      return g;
  }
  return -1;
}

template<int m>
constexpr int primitive_root = primitive_root_constexpr(m);

}
#line 4 "test/yosupo/range_affine_range_sum-splay_tree.test.cpp"
#include <bits/stdc++.h>

// Test-file convenience aliases.
using namespace std;
using ll = long long;
// Modulus used by the range_affine_range_sum problem.
constexpr ll MOD = 998244353;
// Modular integer type for all values and affine coefficients.
using mint = Modular<MOD>;

struct Sum {
  mint a=0;
  int sz=0;
  Sum operator*(const Sum& r) const {
    return {a+r.a, sz+r.sz};
  }
};

struct Affine {
  mint b=1, c=0;
  Affine& operator*=(const Affine& r) {
    b *= r.b;
    c = c*r.b + r.c;
    return *this;
  }
  Sum act(const Sum& a) const {
    return {b*a.a + c*a.sz, a.sz};
  }
};

/// Reads N values and Q queries:
///   0 l r b c : a[i] = b*a[i] + c for i in [l, r)
///   1 l r     : print sum of a[i] for i in [l, r)
int main() {
  cin.tie(nullptr); ios::sync_with_stdio(false);

  int N,Q; cin>>N>>Q;
  vector<Sum> A(N);
  for (auto& a : A) {
    cin>>a.a;
    a.sz = 1;
  }
  SplayTreeList<Sum, Affine> rsq(A.begin(), A.end());

  for (int q = 0; q < Q; q++) {
    int t; cin>>t;
    if (t == 0) {
      int l,r,b,c; cin>>l>>r>>b>>c;
      rsq.update(l,r, {b,c});
    } else if (t == 1) {
      int l,r; cin>>l>>r;
      // '\n' instead of endl: flushing after every query is needless I/O cost.
      cout << rsq.prod(l,r).a << '\n';
    }
  }

  return 0;
}

Test cases

Env Name Status Elapsed Memory
g++ example_00 :heavy_check_mark: AC 6 ms 3 MB
g++ max_random_00 :heavy_check_mark: AC 4652 ms 70 MB
g++ max_random_01 :heavy_check_mark: AC 5061 ms 70 MB
g++ max_random_02 :heavy_check_mark: AC 4766 ms 70 MB
g++ random_00 :heavy_check_mark: AC 3835 ms 55 MB
g++ random_01 :heavy_check_mark: AC 3931 ms 65 MB
g++ random_02 :heavy_check_mark: AC 2864 ms 10 MB
g++ small_00 :heavy_check_mark: AC 6 ms 3 MB
g++ small_01 :heavy_check_mark: AC 6 ms 3 MB
g++ small_02 :heavy_check_mark: AC 6 ms 3 MB
g++ small_03 :heavy_check_mark: AC 6 ms 3 MB
g++ small_04 :heavy_check_mark: AC 6 ms 3 MB
g++ small_05 :heavy_check_mark: AC 6 ms 3 MB
g++ small_06 :heavy_check_mark: AC 6 ms 3 MB
g++ small_07 :heavy_check_mark: AC 6 ms 3 MB
g++ small_08 :heavy_check_mark: AC 7 ms 3 MB
g++ small_09 :heavy_check_mark: AC 7 ms 3 MB
g++ small_random_00 :heavy_check_mark: AC 10 ms 4 MB
g++ small_random_01 :heavy_check_mark: AC 9 ms 4 MB
clang++ example_00 :heavy_check_mark: AC 6 ms 3 MB
clang++ max_random_00 :heavy_check_mark: AC 4969 ms 70 MB
clang++ max_random_01 :heavy_check_mark: AC 5163 ms 70 MB
clang++ max_random_02 :heavy_check_mark: AC 5327 ms 70 MB
clang++ random_00 :heavy_check_mark: AC 3778 ms 55 MB
clang++ random_01 :heavy_check_mark: AC 3933 ms 65 MB
clang++ random_02 :heavy_check_mark: AC 2541 ms 10 MB
clang++ small_00 :heavy_check_mark: AC 6 ms 3 MB
clang++ small_01 :heavy_check_mark: AC 7 ms 3 MB
clang++ small_02 :heavy_check_mark: AC 7 ms 3 MB
clang++ small_03 :heavy_check_mark: AC 7 ms 3 MB
clang++ small_04 :heavy_check_mark: AC 7 ms 3 MB
clang++ small_05 :heavy_check_mark: AC 7 ms 3 MB
clang++ small_06 :heavy_check_mark: AC 7 ms 3 MB
clang++ small_07 :heavy_check_mark: AC 7 ms 3 MB
clang++ small_08 :heavy_check_mark: AC 7 ms 3 MB
clang++ small_09 :heavy_check_mark: AC 7 ms 3 MB
clang++ small_random_00 :heavy_check_mark: AC 12 ms 4 MB
clang++ small_random_01 :heavy_check_mark: AC 11 ms 4 MB
Back to top page