matsutaku-library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub MatsuTaku/matsutaku-library

:warning: test/atcoder/abc116-c-multiple_sequences-dirichlet-optimized.test.cpp

Code

#define PROBLEM "https://atcoder.jp/contests/arc116/tasks/arc116_c"
#define IGNORE
#include "include/mtl/dirichlet.hpp"
#include "include/mtl/modular.hpp"
#include "include/mtl/enumerate.hpp"
#include <bits/stdc++.h>
using namespace std;
using mint = Modular998244353;

int main() {
  int n,m; cin>>n>>m;
  constexpr int max_logn = 30;
  auto primes = Primes(m);
  // Sum of zeta^k for m
  auto f = Identity<mint>(m);
  vector<mint> A(max_logn);
  for (int k = 0; k < max_logn; k++) {
    A[k] = f.second[1];
    f = DirichletConvolveZeta(m, primes, f);
  }
  // Sum of (zeta-1)^k for m
  vector<mint> B(max_logn);
  Enumerate<mint> enm;
  for (int k = 0; k < max_logn; k++) {
    for (int i = 0; i <= k; i++) {
      auto coef = enm.cmb(k, i);
      if ((k-i)%2) coef = -coef;
      B[k] += A[i] * coef;
    }
  }
  // ans = sum_k B[k] binom(n, k)
  mint binom = 1;
  mint ans = 0;
  for (int k = 0; k < max_logn; k++) {
    ans += binom * B[k];
    binom *= mint(n-k) * mint(k+1).inv();
  }
  cout << ans << endl;
}
#line 1 "test/atcoder/abc116-c-multiple_sequences-dirichlet-optimized.test.cpp"
#define PROBLEM "https://atcoder.jp/contests/arc116/tasks/arc116_c"
#define IGNORE
#line 2 "include/mtl/dirichlet.hpp"
#include <vector>
#include <cmath>
#include <cassert>

/**
 * Smallest-prime-factor sieve.
 *
 * Returns a vector p of size n+1 where p[0] = 0, p[1] = 1 and, for i >= 2,
 * p[i] is the smallest prime dividing i (so p[i] == i exactly when i is prime).
 * A negative n yields an empty vector instead of touching memory out of range.
 *
 * Declared inline because the definition lives in a header.
 */
inline std::vector<int> EratosthenesSieve(const int n) {
  if (n < 0)
    return {};
  std::vector<int> p(n+1);
  if (n == 0)
    return p;
  p[1] = 1;
  for (int i = 2; i <= n; i++) {
    if (p[i] != 0)
      continue;  // already marked by a smaller prime
    p[i] = i;
    // 64-bit index: j += i cannot overflow even when n is close to INT_MAX.
    for (long long j = 2LL * i; j <= n; j += i) {
      if (p[j] == 0)
        p[j] = i;
    }
  }
  return p;
}

/**
 * Lists all primes in [2, n] in increasing order.
 *
 * Returns an empty vector for n < 2 without running the sieve.
 * Declared inline because the definition lives in a header.
 */
inline std::vector<int> Primes(const int n) {
  std::vector<int> ps;
  if (n < 2)
    return ps;
  const auto era = EratosthenesSieve(n);
  for (int i = 2; i <= n; i++) {
    // era[i] is the smallest prime factor of i; it equals i only for primes.
    if (era[i] == i) {
      ps.push_back(i);
    }
  }
  return ps;
}

/* Dirichlet convolution c = a * b on indices 1..n, where n = a.size()-1.
 * PseudoCode:
 *   D_c(s) = sum c(n) n^{-s} = sum_n sum_{ij=n} a(i)b(j) (ij)^{-s}
 * Index 0 of every sequence is unused (c[0] stays value-initialized).
 * Entries of b beyond b.size()-1 are treated as zero, so b may be shorter
 * than a; an empty a yields an empty result.
 * complexity: O(n log n)
 */
template<typename T>
std::vector<T> DirichletConvolution(const std::vector<T>& a, const std::vector<T>& b) {
  if (a.empty())
    return {};
  const int n = static_cast<int>(a.size()) - 1;
  const int b_max = static_cast<int>(b.size()) - 1;
  std::vector<T> c(n+1);
  for (int i = 1; i <= n; i++) {
    // Largest j with i*j <= n, clamped so b[j] is never read out of range.
    int m = n / i;
    if (m > b_max)
      m = b_max;
    for (int j = 1; j <= m; j++)
      c[i * j] += a[i] * b[j];
  }
  return c;
}

/* Dirichlet convolution c = a * b for a multiplicative sequence a,
 * on indices 1..n where n = a.size()-1.
 * PseudoCode:
 *   for p in primes:
 *     D_{a,p}(s) = sum_k a(p^k) p^{-ks} (p-part of D_a)
 *     D_b(s) <- D_b(s) D_{a,p}(s)
 * requirements:
 *   - Sequence a should be multiplicative.
 *     D_a(s) = prod_p sum_k a(p^k) p^{-ks}
 *   - Only a[p^e] with e >= 1 is read; a[1] is never read, i.e. the update
 *     behaves as if a(1) = 1.
 * b is copied and resized to n+1 entries (missing entries become T()).
 * An empty a yields an empty result.
 * complexity: O(n log log n)
 */
template<typename T>
std::vector<T> DirichletMultinomialConvolution(const std::vector<T>& a, const std::vector<T>& b) {
  if (a.empty())
    return {};
  const int n = static_cast<int>(a.size()) - 1;
  auto c = b;
  c.resize(n+1);
  for (int p : Primes(n)) {
    int m = n/p;
    // Descending i: every c[j] read below has index j < u and is therefore
    // still the value from before this prime was applied.
    for (int i = m; i >= 1; i--) {
      int u = p * i;
      int q = p, j = i;
      // Accumulate a(p^e) * c[u / p^e] for every e >= 1 with p^e | u.
      while (true) {
        c[u] += a[q] * c[j];
        if (j % p != 0)
          break;
        q *= p;
        j /= p;
      }
    }
  }
  return c;
}

/**
 * Builds the Dirichlet identity e (e(1) = 1, e(i) = 0 otherwise) in the
 * two-part form used by the DirichletConvolve* routines:
 *   - first  : values e(0..k) with k = floor(n^(2/3)),
 *   - second : block sums, second[j] = sum_{i <= n/j} e(i) = 1 for 1 <= j <= l,
 *              with l = ceil(n / k); second[0] is unused and set to 0.
 * For n <= 0 the smallest valid shape (k = 1, l = 0) is returned instead of
 * dividing by zero.
 */
template<typename T>
std::pair<std::vector<T>, std::vector<T>> Identity(int n) {
  int k = 1, l = 0;
  if (n > 0) {
    k = std::pow(n, (double) 2 / 3);  // >= 1 whenever n >= 1
    l = (n + k - 1) / k;
  }
  std::vector<T> a(k+1, T(0)), A(l+1, T(1));
  a[1] = 1;
  A[0] = 0;
  return std::make_pair(a, A);
}

/**
 * Builds the all-ones sequence zeta (zeta(i) = 1 for i >= 1) in the two-part
 * form used by the DirichletConvolve* routines:
 *   - first  : values zeta(0..k) with k = floor(n^(2/3)); index 0 is set to 0,
 *   - second : block sums, second[j] = sum_{i <= n/j} zeta(i) = n / j for
 *              1 <= j <= l, with l = ceil(n / k); second[0] is left as T().
 * For n <= 0 the smallest valid shape (k = 1, l = 0) is returned instead of
 * dividing by zero.
 */
template<typename T>
std::pair<std::vector<T>, std::vector<T>> Zeta(int n) {
  int k = 1, l = 0;
  if (n > 0) {
    k = std::pow(n, (double) 2 / 3);  // >= 1 whenever n >= 1
    l = (n + k - 1) / k;
  }
  std::vector<T> a(k+1, T(1)), A(l+1);
  a[0] = 0;
  for (int i = 1; i <= l; i++)
    A[i] = n / i;
  return std::make_pair(a, A);
}

/**
 * Dirichlet convolution of two sequences given in two-part form
 * (first = values at 0..k, second[j] = prefix sum up to N / j for 1..l).
 *
 * The dense part of the result is DirichletConvolution(a, b); each block sum
 * second[j] is evaluated for n = N / j by splitting the pairs (x, y) with
 * x*y <= n at floor(sqrt(n)).
 *
 * Asserts k * l >= N. Both operands are indexed with the sizes taken from _a,
 * so _b is expected to have the same shape.
 */
template<typename T>
std::pair<std::vector<T>, std::vector<T>> DirichletConvolveOptimal(int N, const std::pair<std::vector<T>, std::vector<T>>& _a, const std::pair<std::vector<T>, std::vector<T>>& _b) {
  const std::vector<T>& a = _a.first;
  const std::vector<T>& A = _a.second;
  const std::vector<T>& b = _b.first;
  const std::vector<T>& B = _b.second;
  const int k = a.size()-1;
  const int l = A.size()-1;
  assert(k * l >= N);

  // Running totals of the dense parts, so that prefix sums at small
  // arguments are available without consulting the block sums.
  std::vector<T> prefA(a);
  std::vector<T> prefB(b);
  for (int i = 1; i <= k; ++i) {
    prefA[i] += prefA[i-1];
    prefB[i] += prefB[i-1];
  }
  // Prefix sum up to x: dense table for x <= k, otherwise the block sum
  // stored at index N / x.
  const auto sumA = [&](int x) -> T {
    if (x <= k)
      return prefA[x];
    return A[N / x];
  };
  const auto sumB = [&](int x) -> T {
    if (x <= k)
      return prefB[x];
    return B[N / x];
  };

  std::vector<T> dense = DirichletConvolution(a, b);

  std::vector<T> blocks(l+1);
  for (int j = 1; j <= l; ++j) {
    const int n = N / j;
    const int root = std::sqrt(n);
    T& total = blocks[j];
    for (int i = 1; i <= root; ++i) {
      // x = i <= root, any y <= n / i
      total += a[i] * sumB(n / i);
      // y = i <= root, root < x <= n / i
      total += (sumA(n / i) - sumA(root)) * b[i];
    }
  }
  return std::make_pair(dense, blocks);
}

/**
 * Dirichlet-convolves a sequence in two-part form with the all-ones
 * function zeta.
 *
 *   - _a.first  : values at 0..k,
 *   - _a.second : second[j] = prefix sum up to N / j for 1..l.
 *
 * The dense part is transformed in place by adding c[i] into c[p*i] for every
 * listed prime p (divisor-sum transform). Each block sum for n = N / j is
 * evaluated by splitting the pairs (x, y) with x*y <= n at floor(sqrt(n)).
 */
template<typename T>
std::pair<std::vector<T>, std::vector<T>> DirichletConvolveZeta(int N, const std::vector<int>& primes, const std::pair<std::vector<T>, std::vector<T>>& _a) {
  const std::vector<T>& small = _a.first;
  const std::vector<T>& large = _a.second;
  const int small_max = small.size()-1;
  const int large_max = large.size()-1;

  // Running totals of the dense part.
  std::vector<T> prefix(small);
  for (int i = 1; i <= small_max; ++i)
    prefix[i] += prefix[i-1];
  // Prefix sum up to x: dense table for x <= small_max, otherwise the block
  // sum stored at index N / x.
  const auto prefix_at = [&](int x) -> T {
    if (x <= small_max)
      return prefix[x];
    return large[N / x];
  };

  // Dense part: accumulate over multiples, one prime at a time.
  std::vector<T> dense(small);
  for (const int p : primes) {
    const int limit = small_max / p;
    for (int i = 1; i <= limit; ++i)
      dense[p * i] += dense[i];
  }

  std::vector<T> blocks(large_max + 1);
  for (int j = 1; j <= large_max; ++j) {
    const int n = N / j;
    const int root = std::sqrt(n);
    T& total = blocks[j];
    for (int i = 1; i <= root; ++i) {
      // x = i <= root contributes a(i) for each of the n / i values of y
      total += small[i] * (n / i);
      // y = i <= root, root < x <= n / i
      total += prefix_at(n / i) - prefix_at(root);
    }
  }
  return std::make_pair(dense, blocks);
}
#line 2 "include/mtl/bit_manip.hpp"
#include <cstdint>
#line 4 "include/mtl/bit_manip.hpp"
#if __cplusplus >= 202002L
#ifndef MTL_CPP20
#define MTL_CPP20
#endif
#include <bit>
#endif

/// Bit-manipulation helpers on 64-bit (and a few 8-bit) words.
/// When MTL_CPP20 is defined the <bit> functions are used; otherwise the
/// portable fallbacks below are compiled.
namespace bm {

/// Count 1s for each 8 bits: byte i of the result holds the popcount of byte i of x.
inline constexpr uint64_t popcnt_e8(uint64_t x) {
  x = (x & 0x5555555555555555) + ((x>>1) & 0x5555555555555555);
  x = (x & 0x3333333333333333) + ((x>>2) & 0x3333333333333333);
  x = (x & 0x0F0F0F0F0F0F0F0F) + ((x>>4) & 0x0F0F0F0F0F0F0F0F);
  return x;
}
/// Count 1s
inline constexpr unsigned popcnt(uint64_t x) {
#ifdef MTL_CPP20
  return std::popcount(x);
#else
  // Multiplying by 0x0101.. sums all byte counts into the top byte.
  return (popcnt_e8(x) * 0x0101010101010101) >> 56;
#endif
}
/// Alias of bm::popcnt(x)
constexpr unsigned popcount(uint64_t x) {
  return popcnt(x);
}
/// Count trailing 0s. s.t. *11011000 -> 3 (64 for x == 0)
inline constexpr unsigned ctz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_zero(x);
#else
  // (lowest set bit) - 1 is a mask of exactly the trailing zeros.
  return popcnt((x & (-x)) - 1);
#endif
}
/// Alias of bm::ctz(x)
constexpr unsigned countr_zero(uint64_t x) {
  return ctz(x);
}
/// Count trailing 1s. s.t. *11011011 -> 2
inline constexpr unsigned cto(uint64_t x) {
#ifdef MTL_CPP20
  return std::countr_one(x);
#else
  return ctz(~x);
#endif
}
/// Alias of bm::cto(x)
constexpr unsigned countr_one(uint64_t x) {
  return cto(x);
}
/// Count trailing 0s of an 8-bit value (8 for x == 0)
inline constexpr unsigned ctz8(uint8_t x) {
  return x == 0 ? 8 : popcnt_e8((x & (-x)) - 1);
}
/// Per-byte non-zero summary: bit i of the result is 1 iff byte i of x is non-zero.
/// [00..0](8bit) -> 0, [**..*](not only 0) -> 1
inline constexpr uint8_t summary(uint64_t x) {
  constexpr uint64_t hmask = 0x8080808080808080ull;
  constexpr uint64_t lmask = 0x7F7F7F7F7F7F7F7Full;
  auto a = x & hmask;
  auto b = x & lmask;
  // Per byte: 0x80 - low7 keeps bit 7 set only when the low 7 bits are zero.
  b = hmask - b;
  b = ~b;
  auto c = (a | b) & hmask;
  // Gather the eight bit-7 flags into the top byte.
  c *= 0x0002040810204081ull;
  return uint8_t(c >> 56);
}
/// Extract target area of bits: len bits of x starting at bit start
inline constexpr uint64_t bextr(uint64_t x, unsigned start, unsigned len) {
  uint64_t mask = len < 64 ? (1ull<<len)-1 : 0xFFFFFFFFFFFFFFFFull;
  return (x >> start) & mask;
}
/// Bit width of an 8-bit value (0 for x == 0).
/// 00101101 -> 00111111 -count_1s-> 6
inline constexpr unsigned log2p1(uint8_t x) {
  // The byte-parallel comparison below miscounts x == 0: the borrow chain
  // leaves the top byte at 0x7F, which would be reported as one set bit.
  if (x == 0)
    return 0;
  if (x & 0x80)
    return 8;
  // Byte i becomes x - 2^i; its high bit ends up set exactly when x < 2^i.
  uint64_t p = uint64_t(x) * 0x0101010101010101ull;
  p -= 0x8040201008040201ull;
  p = ~p & 0x8080808080808080ull;
  // Count the bytes with x >= 2^i.
  p = (p >> 7) * 0x0101010101010101ull;
  p >>= 56;
  return p;
}
/// Index of the most significant set bit; requires x != 0.
/// 00101100 -mask_mssb-> 00100000 -to_index-> 5
inline constexpr unsigned mssb8(uint8_t x) {
  assert(x != 0);
  return log2p1(x) - 1;
}
/// Index of the least significant set bit; requires x != 0.
/// 00101100 -mask_lssb-> 00000100 -to_index-> 2
inline constexpr unsigned lssb8(uint8_t x) {
  assert(x != 0);
  return popcnt_e8((x & -x) - 1);
}
/// Count leading 0s. 00001011... -> 4 (64 for x == 0)
inline constexpr unsigned clz(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_zero(x);
#else
  if (x == 0)
    return 64;
  // i: index of the highest non-zero byte, j: highest set bit inside it.
  auto i = mssb8(summary(x));
  auto j = mssb8(bextr(x, 8 * i, 8));
  return 63 - (8 * i + j);
#endif
}
/// Alias of bm::clz(x)
constexpr unsigned countl_zero(uint64_t x) {
  return clz(x);
}
/// Count leading 1s. 11110100... -> 4
inline constexpr unsigned clo(uint64_t x) {
#ifdef MTL_CPP20
  return std::countl_one(x);
#else
  return clz(~x);
#endif
}
/// Alias of bm::clo(x)
constexpr unsigned countl_one(uint64_t x) {
  return clo(x);
}

/// Count leading 0s of an 8-bit value (8 for x == 0)
inline constexpr unsigned clz8(uint8_t x) {
  return x == 0 ? 8 : 7 - mssb8(x);
}
/// Reverse the order of all 64 bits
inline constexpr uint64_t bit_reverse(uint64_t x) {
  x = ((x & 0x00000000FFFFFFFF) << 32) | ((x & 0xFFFFFFFF00000000) >> 32);
  x = ((x & 0x0000FFFF0000FFFF) << 16) | ((x & 0xFFFF0000FFFF0000) >> 16);
  x = ((x & 0x00FF00FF00FF00FF) << 8) | ((x & 0xFF00FF00FF00FF00) >> 8);
  x = ((x & 0x0F0F0F0F0F0F0F0F) << 4) | ((x & 0xF0F0F0F0F0F0F0F0) >> 4);
  x = ((x & 0x3333333333333333) << 2) | ((x & 0xCCCCCCCCCCCCCCCC) >> 2);
  x = ((x & 0x5555555555555555) << 1) | ((x & 0xAAAAAAAAAAAAAAAA) >> 1);
  return x;
}

/// Check if x is power of 2. 00100000 -> true, 00100001 -> false
constexpr bool has_single_bit(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::has_single_bit(x);
#else
  return x != 0 && (x & (x - 1)) == 0;
#endif
}

/// Bit width needs to represent x. 00110110 -> 6
constexpr int bit_width(uint64_t x) noexcept {
#ifdef MTL_CPP20
  return std::bit_width(x);
#else
  return 64 - clz(x);
#endif
}

/// Ceil power of 2. 00110110 -> 01000000
constexpr uint64_t bit_ceil(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_ceil(x);
#else
  if (x == 0) return 1;
  return 1ull << bit_width(x - 1);
#endif
}

/// Floor power of 2. 00110110 -> 00100000
constexpr uint64_t bit_floor(uint64_t x) {
#ifdef MTL_CPP20
  return std::bit_floor(x);
#else
  if (x == 0) return 0;
  return 1ull << (bit_width(x) - 1);
#endif
}

} // namespace bm
#line 3 "include/mtl/modular.hpp"
#include <iostream>
#line 5 "include/mtl/modular.hpp"

/**
 * Integer modulo the compile-time constant MOD.
 *
 * The stored value is kept in [0, MOD). inv() uses Fermat's little theorem
 * (x^(MOD-2)) and sqrt() uses prime-modulus algorithms, so both assume that
 * MOD is prime.
 */
template <int MOD>
class Modular {
 private:
  unsigned int val_;

 public:
  static constexpr unsigned int mod() { return MOD; }
  /// Reduces a possibly negative integer into [0, MOD).
  template<class T>
  static constexpr unsigned int safe_mod(T v) {
    auto x = (long long)(v%(long long)mod());
    if (x < 0) x += mod();
    return (unsigned int) x;
  }

  constexpr Modular() : val_(0) {}
  /// Construction from unsigned integers: plain remainder.
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && std::is_unsigned<T>::value
      > * = nullptr>
  constexpr Modular(T v) : val_(v%mod()) {}
  /// Construction from signed integers: negative values wrap into [0, MOD).
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && !std::is_unsigned<T>::value
      > * = nullptr>
  constexpr Modular(T v) : val_(safe_mod(v)) {}

  constexpr unsigned int val() const { return val_; }
  constexpr Modular& operator+=(Modular x) {
    val_ += x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  // mod() - 0 == mod() is reduced back to 0 by the unsigned constructor.
  constexpr Modular operator-() const { return {mod() - val_}; }
  constexpr Modular& operator-=(Modular x) {
    val_ += mod() - x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr Modular& operator*=(Modular x) {
    auto v = (long long) val_ * x.val();
    if (v >= mod()) v %= mod();
    val_ = v;
    return *this;
  }
  /// Binary exponentiation; requires p >= 0.
  constexpr Modular pow(long long p) const {
    assert(p >= 0);
    Modular t = 1;
    Modular u = *this;
    while (p) {
      if (p & 1)
        t *= u;
      u *= u;
      p >>= 1;
    }
    return t;
  }
  friend constexpr Modular pow(Modular x, long long p) {
    return x.pow(p);
  }
  /// Multiplicative inverse as x^(MOD-2); valid when MOD is prime and x != 0.
  constexpr Modular inv() const { return pow(mod()-2); }
  constexpr Modular& operator/=(Modular x) { return *this *= x.inv(); }
  constexpr Modular operator+(Modular x) const { return Modular(*this) += x; }
  constexpr Modular operator-(Modular x) const { return Modular(*this) -= x; }
  constexpr Modular operator*(Modular x) const { return Modular(*this) *= x; }
  constexpr Modular operator/(Modular x) const { return Modular(*this) /= x; }
  constexpr Modular& operator++() { return *this += 1; }
  constexpr Modular operator++(int) { Modular c = *this; ++(*this); return c; }
  constexpr Modular& operator--() { return *this -= 1; }
  constexpr Modular operator--(int) { Modular c = *this; --(*this); return c; }

  constexpr bool operator==(Modular x) const { return val() == x.val(); }
  constexpr bool operator!=(Modular x) const { return val() != x.val(); }

  /// Euler's criterion; 0 counts as a square (0 * 0 = 0), matching DynamicModular.
  constexpr bool is_square() const {
    return val() == 0 || pow((mod()-1)/2) == 1;
  }
  /**
   * Return x s.t. x * x = a mod p
   * Throws std::runtime_error("not square") when no such x exists.
   * reference: https://zenn.dev/peria/articles/c6afc72b6b003c
  */
  constexpr Modular sqrt() const {
    if (!is_square()) 
      throw std::runtime_error("not square");
    // 0 and 1 are their own square roots.
    if (val() < 2)
      return *this;
    auto mod_eight = mod() % 8;
    if (mod_eight == 3 || mod_eight == 7) {
      return pow((mod()+1)/4);
    } else if (mod_eight == 5) {
      auto x = pow((mod()+3)/8);
      if (x * x != *this)
        x *= Modular(2).pow((mod()-1)/4);
      return x;
    } else {
      // mod = 1 (mod 8): Tonelli-Shanks style search.
      // d: a non-residue; mod - 1 = t * 2^s with t odd.
      Modular d = 2;
      while (d.is_square())
        d += 1;
      auto t = mod()-1;
      int s = bm::ctz(t);
      t >>= s;
      auto a = pow(t);
      auto D = d.pow(t);
      // Build m bit by bit so that a * D^m == 1.
      int m = 0;
      Modular dt = 1;
      Modular du = D;
      for (int i = 0; i < s; i++) {
        if ((a*dt).pow(1u<<(s-1-i)) == -1) {
          m |= 1u << i;
          dt *= du;
        }
        du *= du;
      }
      return pow((t+1)/2) * D.pow(m/2);
    }
  }

  friend std::ostream& operator<<(std::ostream& os, const Modular& x) {
    return os << x.val();
  }
  /// Reads a (possibly negative or >= MOD) integer and stores it reduced
  /// into [0, MOD), so the class invariant holds for any input.
  friend std::istream& operator>>(std::istream& is, Modular& x) {
    long long v = 0;
    is >> v;
    x.val_ = safe_mod(v);
    return is;
  }

};

/// Modular integers for the modulus 998244353.
using Modular998244353 = Modular<998244353>;
/// Modular integers for the modulus 10^9 + 7.
using Modular1000000007 = Modular<1000000007>;

/**
 * Integer modulo a modulus chosen at run time.
 *
 * The modulus is a static member shared by all objects with the same Id and
 * is zero until set_mod() is called, so set_mod() must be called before any
 * value is constructed (construction takes a remainder by mod()).
 * inv() and sqrt() assume the modulus is prime.
 */
template<int Id=0>
class DynamicModular {
 private:
  static unsigned int mod_;
  unsigned int val_;

 public:
  static unsigned int mod() { return mod_; }
  static void set_mod(unsigned int m) { mod_ = m; }
  /// Reduces a possibly negative integer into [0, mod()).
  template<class T>
  static constexpr unsigned int safe_mod(T v) {
    auto x = (long long)(v%(long long)mod());
    if (x < 0) x += mod();
    return (unsigned int) x;
  }

  constexpr DynamicModular() : val_(0) {}
  /// Construction from unsigned integers: plain remainder.
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && std::is_unsigned<T>::value
      > * = nullptr>
  constexpr DynamicModular(T v) : val_(v%mod()) {}
  /// Construction from signed integers: negative values wrap into [0, mod()).
  template<class T,
      std::enable_if_t<
          std::is_integral<T>::value && !std::is_unsigned<T>::value
      > * = nullptr>
  constexpr DynamicModular(T v) : val_(safe_mod(v)) {}

  constexpr unsigned int val() const { return val_; }
  constexpr DynamicModular& operator+=(DynamicModular x) {
    val_ += x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  // mod() - 0 == mod() is reduced back to 0 by the unsigned constructor.
  constexpr DynamicModular operator-() const { return {mod() - val_}; }
  constexpr DynamicModular& operator-=(DynamicModular x) {
    val_ += mod() - x.val();
    if (val_ >= mod()) val_ -= mod();
    return *this;
  }
  constexpr DynamicModular& operator*=(DynamicModular x) {
    auto v = (long long) val_ * x.val();
    if (v >= mod()) v %= mod();
    val_ = v;
    return *this;
  }
  /// Binary exponentiation; requires p >= 0.
  constexpr DynamicModular pow(long long p) const {
    assert(p >= 0);
    DynamicModular t = 1;
    DynamicModular u = *this;
    while (p) {
      if (p & 1)
        t *= u;
      u *= u;
      p >>= 1;
    }
    return t;
  }
  friend constexpr DynamicModular pow(DynamicModular x, long long p) {
    return x.pow(p);
  }
  // TODO: implement when mod is not prime
  /// Multiplicative inverse as x^(mod-2); valid when mod is prime and x != 0.
  constexpr DynamicModular inv() const { return pow(mod()-2); }
  constexpr DynamicModular& operator/=(DynamicModular x) { return *this *= x.inv(); }
  constexpr DynamicModular operator+(DynamicModular x) const { return DynamicModular(*this) += x; }
  constexpr DynamicModular operator-(DynamicModular x) const { return DynamicModular(*this) -= x; }
  constexpr DynamicModular operator*(DynamicModular x) const { return DynamicModular(*this) *= x; }
  constexpr DynamicModular operator/(DynamicModular x) const { return DynamicModular(*this) /= x; }
  constexpr DynamicModular& operator++() { return *this += 1; }
  constexpr DynamicModular operator++(int) { DynamicModular c = *this; ++(*this); return c; }
  constexpr DynamicModular& operator--() { return *this -= 1; }
  constexpr DynamicModular operator--(int) { DynamicModular c = *this; --(*this); return c; }

  constexpr bool operator==(DynamicModular x) const { return val() == x.val(); }
  constexpr bool operator!=(DynamicModular x) const { return val() != x.val(); }

  /// Euler's criterion; 0 counts as a square.
  constexpr bool is_square() const {
    return val() == 0 or pow((mod()-1)/2) == 1;
  }
  /**
   * Return x s.t. x * x = a mod p
   * Throws std::runtime_error("not square") when no such x exists.
   * reference: https://zenn.dev/peria/articles/c6afc72b6b003c
  */
  constexpr DynamicModular sqrt() const {
    // assert mod is prime
    if (!is_square()) 
      throw std::runtime_error("not square");
    // 0 and 1 are their own square roots.
    if (val() < 2)
      return val();
    auto mod_eight = mod() % 8;
    if (mod_eight == 3 || mod_eight == 7) {
      return pow((mod()+1)/4);
    } else if (mod_eight == 5) {
      auto x = pow((mod()+3)/8);
      if (x * x != *this)
        x *= DynamicModular(2).pow((mod()-1)/4);
      return x;
    } else {
      // Remaining moduli: Tonelli-Shanks style search.
      // d: a non-residue; mod - 1 = t * 2^s with t odd.
      DynamicModular d = 2;
      while (d.is_square())
        ++d;
      auto t = mod()-1;
      int s = bm::ctz(t);
      t >>= s;
      auto a = pow(t);
      auto D = d.pow(t);
      // Build m bit by bit so that a * D^m == 1.
      int m = 0;
      DynamicModular dt = 1;
      DynamicModular du = D;
      for (int i = 0; i < s; i++) {
        if ((a*dt).pow(1u<<(s-1-i)) == -1) {
          m |= 1u << i;
          dt *= du;
        }
        du *= du;
      }
      return pow((t+1)/2) * D.pow(m/2);
    }
  }

  friend std::ostream& operator<<(std::ostream& os, const DynamicModular& x) {
    return os << x.val();
  }
  /// Reads a (possibly negative or >= mod()) integer and stores it reduced
  /// into [0, mod()), so the class invariant holds for any input.
  friend std::istream& operator>>(std::istream& is, DynamicModular& x) {
    long long v = 0;
    is >> v;
    x.val_ = safe_mod(v);
    return is;
  }

};
// Static storage: zero-initialized until set_mod() is called.
template<int Id>
unsigned int DynamicModular<Id>::mod_;

#line 264 "include/mtl/modular.hpp"

/**
 * Lazily grown table of modular inverses for ModInt.
 *
 * The table is a static member, so it is shared by every ModularUtil<ModInt>
 * object. Entries 0 and 1 are preset to 0 and 1; further entries follow the
 * recurrence inv(i) = -inv(mod % i) * (mod / i).
 */
template<class ModInt>
struct ModularUtil {
  static constexpr int mod = ModInt::mod();
  static struct inv_table {
    std::vector<ModInt> tb{0,1};
    inv_table() : tb({0,1}) {}
  } inv_;
  /// Ensures the table covers indices 0..n.
  void set_inv(int n) {
    auto& table = inv_.tb;
    const int known = table.size();
    if (n < known)
      return;
    table.resize(n+1);
    for (int i = known; i <= n; ++i) {
      table[i] = -table[mod % i] * (mod / i);
    }
  }
  /// Returns a reference to the table entry for i, growing the table if needed.
  ModInt& inv(int i) {
    set_inv(i);
    return inv_.tb[i];
  }
};
template<class ModInt>
typename ModularUtil<ModInt>::inv_table ModularUtil<ModInt>::inv_;

#include <array>

namespace math {

/// Computes x^p mod m with the result in [0, m).
/// Requires m > 0 and p >= 0. A negative base is reduced into [0, m) first,
/// and m == 1 yields 0.
constexpr int mod_pow_constexpr(int x, int p, int m) {
  long long t = 1 % m;
  long long u = x % m;
  if (u < 0) u += m;
  while (p) {
    if (p & 1) {
      t *= u;
      t %= m;
    }
    u *= u;
    u %= m;
    p >>= 1;
  }
  return (int) t;
}

/// Returns the smallest primitive root modulo m, assuming m is prime
/// (common NTT moduli are answered from a table). Returns -1 when m < 2 or
/// when no candidate passes the test.
constexpr int primitive_root_constexpr(int m) {
  if (m < 2) return -1;
  if (m == 2) return 1;
  if (m == 167772161) return 3;
  if (m == 469762049) return 3;
  if (m == 754974721) return 11;
  if (m == 880803841) return 26;
  if (m == 998244353) return 3;

  // Collect the distinct prime factors of m-1.
  std::array<int, 20> divs{};
  int cnt = 0;
  int x = m-1;
  if (x % 2 == 0) {
    divs[cnt++] = 2;
    while (x % 2 == 0)
      x /= 2;
  }
  // d <= x / d is d*d <= x without the risk of overflowing int.
  for (int d = 3; d <= x / d; d += 2) {
    if (x % d == 0) {
      divs[cnt++] = d;
      while (x % d == 0)
        x /= d;
    }
  }
  if (x > 1) divs[cnt++] = x;
  // g is accepted when g^((m-1)/q) != 1 for every prime factor q of m-1.
  for (int g = 2; g < m; g++) {
    bool ok = true;
    for (int i = 0; i < cnt; i++) {
      if (mod_pow_constexpr(g, (m-1) / divs[i], m) == 1) {
        ok = false;
        break;
      }
    }
    if (ok) return g;
  }
  return -1;
}

/// Compile-time primitive root of the (prime) modulus m.
template<int m>
constexpr int primitive_root = primitive_root_constexpr(m);

}
#line 2 "include/mtl/enumerate.hpp"

#line 5 "include/mtl/enumerate.hpp"

/**
 * Factorial / inverse-factorial tables for counting formulas over MODULAR.
 *
 * The tables grow on demand (at least doubling each time), so the member
 * functions are non-const. Inverse factorials are obtained with one division
 * at the top of the new range followed by a downward multiplication sweep.
 */
template <typename MODULAR>
class Enumerate {
 public:
  using mint = MODULAR;
 private:
  int max_n_ = 1;                     // largest index currently filled in
  std::vector<mint> fact_, ifact_;    // fact_[i] = i!, ifact_[i] = 1 / i!

  void _set_max_n(int n);

 public:
  Enumerate() : fact_({1, 1}), ifact_({1, 1}) {}
  /// Pre-computes the tables up to at least n.
  explicit Enumerate(int n) : fact_(std::max(2, n+1)), ifact_(std::max(2, n+1)) {
    fact_[0] = fact_[1] = ifact_[0] = ifact_[1] = 1;
    _set_max_n(n);
  }

  /// Binomial coefficient C(p, q); 0 when q < 0 or p < q.
  mint cmb(int p, int q) {
    if (q < 0 || p < q) return 0;
    return fact(p) * ifact(q) * ifact(p-q);
  }

  /// Number of q-permutations of p items, p! / (p-q)!; 0 when q < 0 or p < q.
  mint prm(int p, int q) {
    if (q < 0 || p < q) return 0;
    return fact(p) * ifact(p-q);
  }

  /// p!; requires p >= 0.
  mint fact(int p) {
    if (p > max_n_)
      _set_max_n(p);
    return fact_[p];
  }
  /// 1 / p!; requires p >= 0.
  mint ifact(int p) {
    if (p > max_n_)
      _set_max_n(p);
    return ifact_[p];
  }

};

/// Extends both tables so that they cover at least index n.
template<typename MODULAR>
void Enumerate<MODULAR>::_set_max_n(int n) {
  if (n <= max_n_)
    return;
  // Grow geometrically so repeated small extensions stay cheap.
  int nxtn = std::max(max_n_*2, n);
  fact_.resize(nxtn+1);
  ifact_.resize(nxtn+1);
  for (int i = max_n_+1; i <= nxtn; i++) {
    fact_[i] = fact_[i-1] * i;
  }
  // One division at the top, then 1/i! = 1/(i+1)! * (i+1) going down.
  ifact_[nxtn] = mint(1) / fact_[nxtn];
  for (int i = nxtn-1; i > max_n_; i--) {
    ifact_[i] = ifact_[i+1] * (i+1);
  }
  max_n_ = nxtn;
}

#line 6 "test/atcoder/abc116-c-multiple_sequences-dirichlet-optimized.test.cpp"
#include <bits/stdc++.h>
using namespace std;
using mint = Modular998244353;

int main() {
  int n,m; cin>>n>>m;
  constexpr int max_logn = 30;
  auto primes = Primes(m);
  // Sum of zeta^k for m
  auto f = Identity<mint>(m);
  vector<mint> A(max_logn);
  for (int k = 0; k < max_logn; k++) {
    A[k] = f.second[1];
    f = DirichletConvolveZeta(m, primes, f);
  }
  // Sum of (zeta-1)^k for m
  vector<mint> B(max_logn);
  Enumerate<mint> enm;
  for (int k = 0; k < max_logn; k++) {
    for (int i = 0; i <= k; i++) {
      auto coef = enm.cmb(k, i);
      if ((k-i)%2) coef = -coef;
      B[k] += A[i] * coef;
    }
  }
  // ans = sum_k B[k] binom(n, k)
  mint binom = 1;
  mint ans = 0;
  for (int k = 0; k < max_logn; k++) {
    ans += binom * B[k];
    binom *= mint(n-k) * mint(k+1).inv();
  }
  cout << ans << endl;
}
Back to top page