#include <bits/stdc++.h>
using namespace std;
// NOTE: from here on every `int` token is textually replaced by `long long`,
// so all "int" variables, parameters and return types in this file are 64-bit.
#define int long long
// A polynomial is a coefficient vector; index i holds the coefficient of x^i
// (the convolution in operator* adds a[i] * b[j] into slot i + j).
using poly = vector<int>;
// N: capacity of the scratch arrays in namespace Poly.
// P: modulus used by add/sub/mul/qpow.
// G: base from which ntt() derives its roots of unity.
const int N = 4e6 + 10, P = 998244353, G = 3;
// Modular addition. Only a single conditional subtraction of P is applied,
// so the result lies in [0, P) when both operands are already in [0, P).
int add(int a, int b) {
    const int total = a + b;
    return total >= P ? total - P : total;
}
// Modular subtraction. Only a single conditional addition of P is applied,
// so the result lies in [0, P) when both operands are already in [0, P).
int sub(int a, int b) {
    const int diff = a - b;
    return diff < 0 ? diff + P : diff;
}
// Modular multiplication: the product is formed first and then reduced with %,
// so it is exact as long as a * b fits in the (64-bit) integer type.
int mul(int a, int b) {
    const int product = a * b;
    return product % P;
}
int qpow(int a, int b) {
int res = 1;
while (b) {
if (b & 1) res = mul(res, a);
a = mul(a, a);
b >>= 1;
}
return res;
}
// Forward declarations of the polynomial operators. They are defined after
// namespace Poly, but Poly::div already needs operator*, operator~ and operator-.
poly operator+(const poly& a, const poly& b);
poly operator-(const poly& a, const poly& b);
poly operator*(const poly& a, const poly& b);
poly operator/(const poly& a, const poly& b);
poly operator%(const poly& a, const poly& b);
poly operator~(const poly& a);
// Gi = G^(P-2) mod P, i.e. the modular inverse of G (Fermat, P prime).
// ntt() uses it in place of G whenever its `type` argument is not 1.
const int Gi = qpow(G, P - 2);
// NTT-based building blocks for the polynomial operators. Everything here
// works on coefficients modulo P and shares the global scratch state below,
// so none of these routines is reentrant.
namespace Poly {
// f, g, h: scratch coefficient buffers of capacity N.
// lim: current transform length (a power of two); l: log2(lim).
// r: bit-reversal permutation table for the current lim.
int f[N], g[N], h[N], lim, l, r[N];

// In-place iterative butterfly transform of a[0..lim), using the permutation
// stored in r[]. type == 1 derives the roots of unity from G; any other value
// derives them from Gi. No 1/lim scaling is applied here: callers multiply by
// qpow(lim, P - 2) themselves after transforming back.
void ntt(int a[], int type) {
    // Bring the input into bit-reversed order; `i < r[i]` swaps each pair once.
    for (int i = 0; i < lim; i++) {
        if (i < r[i]) swap(a[i], a[r[i]]);
    }
    // x is the size of the sub-transforms being merged at this level.
    for (int x = 2; x <= lim; x <<= 1) {
        const int half = x / 2;
        const int wn = qpow(type == 1 ? G : Gi, (P - 1) / x);
        for (int i = 0; i < lim; i += x) {
            int w = 1;
            for (int j = 0; j < half; j++, w = mul(w, wn)) {
                const int u = a[i + j];
                const int v = mul(a[i + half + j], w);
                a[i + j] = add(u, v);
                a[i + half + j] = sub(u, v);
            }
        }
    }
}

// Writes into b[0..n) the first n coefficients of the power-series inverse of
// a, i.e. a * b == 1 modulo x^n (coefficients modulo P).
//
// The precision is doubled recursively with the update b <- 2b - a * b^2.
// Only a[0..n) is read. b must have room for `lim` entries, where lim is the
// smallest power of two >= 2n (and lim must not exceed N, the size of f and r);
// on return b[n..lim) is zero. Clobbers f, lim, l and r. a[0] must be
// invertible modulo P, otherwise the result is meaningless.
// For n <= 0 nothing is written (previously this recursed without end).
void inv(int a[], int b[], int n) {
    if (n <= 0) return;
    if (n == 1) {
        b[0] = qpow(a[0], P - 2);
        return;
    }
    const int m = (n + 1) >> 1;
    inv(a, b, m);  // b now holds the inverse modulo x^m

    lim = 1, l = 0;
    while (lim < 2 * n) lim <<= 1, l++;
    for (int i = 0; i < lim; i++) {
        // r[0] is always 0, so the recurrence is well defined from i == 0 on.
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
        f[i] = 0;
    }
    for (int i = 0; i < n; i++) f[i] = a[i];
    for (int i = m; i < lim; i++) b[i] = 0;  // drop anything beyond x^m

    ntt(f, 1);
    ntt(b, 1);
    for (int i = 0; i < lim; i++) {
        f[i] = sub(mul(b[i], 2), mul(f[i], mul(b[i], b[i])));
    }
    ntt(f, -1);

    const int t = qpow(lim, P - 2);  // 1/lim scaling omitted by ntt()
    for (int i = 0; i < n; i++) b[i] = mul(f[i], t);
    for (int i = n; i < lim; i++) b[i] = 0;
}

// Euclidean division of polynomials: returns {q, rem} with a == q * b + rem.
//
// With n = a.size() - 1 and m = b.size() - 1, q has n - m + 1 coefficients and
// rem has m coefficients. The quotient is obtained from the reversed
// polynomials via a power-series inverse of reversed b modulo x^(n-m+1).
// b must be non-empty and its last (leading) coefficient must be invertible.
// When a has fewer coefficients than b the quotient is {0} and the remainder
// is a zero-padded to m coefficients (previously this case tried to resize a
// vector to a negative length).
pair<poly, poly> div(const poly &a, const poly &b) {
    const int n = static_cast<int>(a.size()) - 1;
    const int m = static_cast<int>(b.size()) - 1;
    if (n < m) {
        poly rem = a;
        rem.resize(m);  // m > n >= -1, so m >= 0 here
        return make_pair(poly{0}, std::move(rem));
    }

    const int qlen = n - m + 1;
    poly rev_a(a.rbegin(), a.rend());
    poly rev_b(b.rbegin(), b.rend());
    rev_a.resize(qlen);
    rev_b.resize(qlen);

    poly quot = rev_a * ~rev_b;
    quot.resize(qlen);
    reverse(quot.begin(), quot.end());

    poly rem = a - quot * b;
    rem.resize(m);
    return make_pair(std::move(quot), std::move(rem));
}
}
// Coefficient-wise sum modulo P. The result has as many coefficients as the
// longer operand; the shorter one is treated as zero-padded.
poly operator+(const poly& a, const poly& b) {
    static poly sum;  // storage is reused between calls; a copy is returned
    const size_t width = max(a.size(), b.size());
    sum.assign(width, 0);
    for (size_t k = 0; k < a.size(); ++k) sum[k] = add(sum[k], a[k]);
    for (size_t k = 0; k < b.size(); ++k) sum[k] = add(sum[k], b[k]);
    return sum;
}
// Coefficient-wise difference a - b modulo P. The result has as many
// coefficients as the longer operand; the shorter one is treated as zero-padded.
poly operator-(const poly& a, const poly& b) {
    static poly diff;  // storage is reused between calls; a copy is returned
    const size_t width = max(a.size(), b.size());
    diff.assign(width, 0);
    for (size_t k = 0; k < a.size(); ++k) diff[k] = add(diff[k], a[k]);
    for (size_t k = 0; k < b.size(); ++k) diff[k] = sub(diff[k], b[k]);
    return diff;
}
// Polynomial product modulo P; the result has a.size() + b.size() - 1
// coefficients.
//
// Small inputs (a.size() * b.size() <= 500) use the schoolbook double loop;
// larger ones are multiplied with the NTT, which clobbers Poly::f, Poly::g,
// Poly::lim, Poly::l and Poly::r. The transform length (smallest power of two
// >= the result size) must not exceed N, the capacity of those buffers.
// If the result size would be <= 0 (both operands empty, or one empty and the
// other a single coefficient) an empty polynomial is returned; previously two
// empty operands led to a resize with a negative length.
poly operator*(const poly& a, const poly& b) {
    const int sz = static_cast<int>(a.size() + b.size()) - 1;
    if (sz <= 0) return poly();

    poly c(sz, 0);
    if (a.size() * b.size() <= 500) {
        for (size_t i = 0; i < a.size(); i++) {
            for (size_t j = 0; j < b.size(); j++) {
                c[i + j] = add(c[i + j], mul(a[i], b[j]));
            }
        }
        return c;
    }

    using namespace Poly;
    // lim >= sz is enough to hold the linear convolution without wrap-around.
    lim = 1, l = 0;
    while (lim < sz) lim <<= 1, l++;
    for (int i = 0; i < lim; i++) {
        // r[0] is always 0, so the recurrence is well defined from i == 0 on.
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
        f[i] = g[i] = 0;
    }
    std::copy(a.begin(), a.end(), f);
    std::copy(b.begin(), b.end(), g);

    ntt(f, 1);
    ntt(g, 1);
    for (int i = 0; i < lim; i++) f[i] = mul(f[i], g[i]);
    ntt(f, -1);

    const int t = qpow(lim, P - 2);  // 1/lim scaling omitted by ntt()
    for (int i = 0; i < sz; i++) c[i] = mul(f[i], t);
    return c;
}
// Power-series inverse: returns b with a.size() coefficients such that
// a * b == 1 modulo x^a.size() (coefficients modulo P).
//
// Uses Poly::h as input buffer and Poly::g as output buffer, and Poly::inv
// additionally clobbers Poly::f, Poly::lim, Poly::l and Poly::r. a[0] must be
// invertible modulo P, and the smallest power of two >= 2 * a.size() must not
// exceed N. An empty input yields an empty result (previously it sent
// Poly::inv into unbounded recursion).
poly operator~(const poly &a) {
    using namespace Poly;
    const int sz = static_cast<int>(a.size());
    if (sz == 0) return poly();
    std::copy(a.begin(), a.end(), h);
    inv(h, g, sz);
    return poly(g, g + sz);
}
// Quotient of the polynomial division a / b (first component of Poly::div).
poly operator/(const poly& a, const poly& b) {
    poly quotient = Poly::div(a, b).first;
    return quotient;
}
// Remainder of the polynomial division a / b (second component of Poly::div).
poly operator%(const poly& a, const poly& b) {
    poly remainder = Poly::div(a, b).second;
    return remainder;
}
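// Author: ZnPdCo
// Source: https://znpdco.fun/2025/02/16/poly/
// All content of the source page is provided under the terms of CC BY-SA 4.0
// and the SATA license; additional terms may also apply.
// (The "Comments" section heading of the source web page followed here.)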