【2 级置顶】多项式模板

OI
文章目录
#include <bits/stdc++.h>
using namespace std;
// Every `int` written below is really `long long` (at least 64 bits). mul() depends on
// this: the product of two residues below P must not overflow before the `% P`.
#define int long long
// A polynomial is stored as its coefficient vector, lowest degree first
// (index i holds the coefficient of x^i, as used by operator*).
using poly = vector<int>;
// N: capacity of each static scratch array in namespace Poly.
// P: modulus for all arithmetic in this file.
// G: base whose powers are used as roots of unity by Poly::ntt.
const int N = 4e6 + 10, P = 998244353, G = 3;
// Modular addition: returns a + b, reduced by a single subtraction of P when the
// sum reaches P. Intended for operands that already lie in [0, P).
int add(int a, int b) {
    const int sum = a + b;
    return sum >= P ? sum - P : sum;
}
// Modular subtraction: returns a - b, lifted by a single addition of P when the
// difference is negative. Intended for operands that already lie in [0, P).
int sub(int a, int b) {
    const int diff = a - b;
    return diff < 0 ? diff + P : diff;
}
// Modular multiplication: returns (a * b) % P. The full product is formed before
// the reduction, so it must fit in the underlying integer type.
int mul(int a, int b) {
    const int product = a * b;
    return product % P;
}
// Binary exponentiation: returns a^b modulo P using mul().
// Walks the bits of b from least to most significant, squaring the base each step
// and folding it into the result whenever the current bit is set.
int qpow(int a, int b) {
    int result = 1;
    for (; b; b >>= 1) {
        if (b & 1) {
            result = mul(result, a);
        }
        a = mul(a, a);
    }
    return result;
}
// Forward declarations of the polynomial operators. They are defined after
// namespace Poly, but Poly::div already uses *, - and ~, so they are declared here.
// Coefficient-wise sum modulo P.
poly operator+(const poly& a, const poly& b);
// Coefficient-wise difference modulo P.
poly operator-(const poly& a, const poly& b);
// Polynomial product modulo P.
poly operator*(const poly& a, const poly& b);
// Quotient of polynomial division (first component of Poly::div).
poly operator/(const poly& a, const poly& b);
// Remainder of polynomial division (second component of Poly::div).
poly operator%(const poly& a, const poly& b);
// Multiplicative inverse as a power series, truncated to a.size() coefficients.
poly operator~(const poly& a);
// G^(P-2) modulo P, i.e. the modular inverse of G; Poly::ntt uses it as the root
// base for the inverse transform.
const int Gi = qpow(G, P - 2);
// Shared scratch state and low-level kernels for NTT-based polynomial arithmetic
// modulo P. Everything here operates on the global buffers declared below, so
// none of it is reentrant.
namespace Poly {
    // f, g, h: coefficient scratch buffers (capacity N each).
    // r:       bit-reversal permutation for the current transform length.
    // lim, l:  current transform length (a power of two) and its base-2 logarithm.
    int f[N], g[N], h[N], lim, l, r[N];

    // In-place number-theoretic transform of a[0..lim).
    //   type == 1   forward transform (root base G)
    //   otherwise   inverse transform (root base Gi)
    // The inverse transform is NOT scaled; callers multiply by lim^(-1) themselves.
    // Reads the globals lim and r, which the caller must have prepared for the
    // desired length beforehand.
    void ntt(int a[], int type) {
        // Bit-reversal reordering; the i < r[i] test makes each pair swap exactly once.
        for (int i = 0; i < lim; i++) if (i < r[i]) swap(a[i], a[r[i]]);
        // Butterfly passes over blocks of width x = 2, 4, ..., lim.
        for (int x = 2; x <= lim; x <<= 1) {
            const int half = x / 2;
            // Root of unity of order x used by this pass.
            const int wn = qpow(type == 1 ? G : Gi, (P - 1) / x);
            for (int i = 0; i < lim; i += x) {
                int w = 1;
                for (int j = 0; j < half; j++, w = mul(w, wn)) {
                    const int u = a[i + j];
                    const int v = mul(a[i + j + half], w);
                    a[i + j] = add(u, v);
                    a[i + j + half] = sub(u, v);
                }
            }
        }
    }

    // Power-series inverse: writes into b[0..n) the coefficients with
    // a * b = 1 (mod x^n), reading a[0..n).
    // Works by recursive doubling: first solves modulo x^ceil(n/2), then refines with
    // b <- b * (2 - a * b) evaluated pointwise in the transform domain.
    // Side effects: overwrites the globals f, r, lim and l, and zeroes b[n..lim),
    // where lim is the smallest power of two >= 2n; b must have room for lim
    // entries. a and b must not alias f.
    // The base case uses a[0]^(P-2), so a[0] is expected to be nonzero modulo P.
    void inv(int a[], int b[], int n) {
        // Nothing to invert. Without this guard n == 0 would recurse forever,
        // because (0 + 1) >> 1 is again 0 and the n == 1 base case is never reached.
        if (n <= 0) return;
        if (n == 1) {
            b[0] = qpow(a[0], P - 2);
            return;
        }
        const int m = (n + 1) >> 1;
        inv(a, b, m);  // b now holds the inverse modulo x^m
        lim = 1, l = 0;
        while (lim < 2 * n) lim <<= 1, l++;
        for (int i = 0; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1)), f[i] = 0;
        for (int i = 0; i < n; i++) f[i] = a[i];
        // Clear everything beyond the m valid coefficients before transforming b.
        for (int i = m; i < lim; i++) b[i] = 0;
        ntt(f, 1);
        ntt(b, 1);
        // Pointwise Newton step: 2*b - a*b*b.
        for (int i = 0; i < lim; i++) f[i] = sub(mul(b[i], 2), mul(f[i], mul(b[i], b[i])));
        ntt(f, -1);
        const int t = qpow(lim, P - 2);  // scaling factor for the unscaled inverse transform
        for (int i = 0; i < n; i++) b[i] = mul(f[i], t);
        for (int i = n; i < lim; i++) b[i] = 0;
    }

    // Polynomial division with remainder: returns (q, r) with a = q * b + r.
    // With n = a.size() - 1 and m = b.size() - 1:
    //   - if n >= m, q has n - m + 1 coefficients and r has m coefficients;
    //   - if n <  m, q is the single coefficient 0 and r is a, zero-padded to m
    //     coefficients (previously this case threw or recursed without end).
    // The quotient is obtained from the reversed coefficient vectors:
    // rev(q) = rev(a) * rev(b)^(-1) (mod x^(n-m+1)).
    // Preconditions: b is non-empty and its last (leading) coefficient is nonzero
    // modulo P, since that coefficient is what gets inverted.
    pair<poly, poly> div(const poly &a, const poly &b) {
        int n = a.size() - 1, m = b.size() - 1;
        if (n < m) {
            poly quot(1);   // one zero coefficient, so the quotient is never empty
            poly rem = a;
            rem.resize(m);  // same remainder length as in the general case
            return make_pair(quot, rem);
        }
        poly rev_a = a, rev_b = b;
        reverse(rev_a.begin(), rev_a.end());
        reverse(rev_b.begin(), rev_b.end());
        // Only the first n - m + 1 coefficients influence the quotient.
        rev_a.resize(n - m + 1);
        rev_b.resize(n - m + 1);
        poly quot = rev_a * ~rev_b;
        quot.resize(n - m + 1);
        reverse(quot.begin(), quot.end());
        poly rem = a - quot * b;
        rem.resize(m);
        return make_pair(quot, rem);
    }
}
// Coefficient-wise sum of two polynomials modulo P.
// The result has max(a.size(), b.size()) coefficients; the shorter operand
// contributes nothing beyond its own length.
poly operator+(const poly& a, const poly& b) {
    static poly total;  // buffer reused across calls; the caller receives a copy
    const int len = max(a.size(), b.size());
    total.assign(len, 0);
    int pos = 0;
    for (int coef : a) {
        total[pos] = add(total[pos], coef);
        ++pos;
    }
    pos = 0;
    for (int coef : b) {
        total[pos] = add(total[pos], coef);
        ++pos;
    }
    return total;
}
// Coefficient-wise difference a - b of two polynomials modulo P.
// The result has max(a.size(), b.size()) coefficients; the shorter operand
// contributes nothing beyond its own length.
poly operator-(const poly& a, const poly& b) {
    static poly total;  // buffer reused across calls; the caller receives a copy
    const int len = max(a.size(), b.size());
    total.assign(len, 0);
    int pos = 0;
    for (int coef : a) {
        total[pos] = add(total[pos], coef);
        ++pos;
    }
    pos = 0;
    for (int coef : b) {
        total[pos] = sub(total[pos], coef);
        ++pos;
    }
    return total;
}
// Polynomial product modulo P.
// The result has a.size() + b.size() - 1 coefficients. If both operands are empty
// the result is empty (previously the size wrapped to -1 and resize() threw).
// Small products (a.size() * b.size() <= 500) use the schoolbook double loop;
// larger ones go through the NTT, using the globals in namespace Poly
// (f, g, r, lim, l are overwritten).
// NOTE(review): coefficients are presumably already reduced to [0, P), since
// add()/sub() only correct by a single multiple of P. The transform length must
// also not exceed N, the capacity of the scratch arrays; neither is checked here.
poly operator*(const poly& a, const poly& b) {
    if (a.empty() && b.empty()) return poly();
    static poly c;  // buffer reused across calls; the caller receives a copy
    const int sz = a.size() + b.size() - 1;
    c.assign(sz, 0);
    if (a.size() * b.size() <= 500) {
        for (int i = 0; i < a.size(); i++) {
            for (int j = 0; j < b.size(); j++) {
                c[i + j] = add(c[i + j], mul(a[i], b[j]));
            }
        }
        return c;
    }
    using namespace Poly;
    // Smallest power of two that holds all sz result coefficients. A cyclic
    // convolution of length >= sz equals the linear one, so there is no need to
    // double again when sz is itself a power of two.
    lim = 1, l = 0;
    while (lim < sz) lim <<= 1, l++;
    for (int i = 0; i < lim; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
        f[i] = g[i] = 0;
    }
    for (int i = 0; i < a.size(); i++) f[i] = a[i];
    for (int i = 0; i < b.size(); i++) g[i] = b[i];
    ntt(f, 1);
    ntt(g, 1);
    for (int i = 0; i < lim; i++) f[i] = mul(f[i], g[i]);
    ntt(f, -1);
    const int t = qpow(lim, P - 2);  // scaling factor for the unscaled inverse transform
    for (int i = 0; i < sz; i++) c[i] = mul(f[i], t);
    return c;
}
// Power-series inverse: returns b with a.size() coefficients such that
// a * b = 1 (mod x^a.size()), computed by Poly::inv.
// An empty input yields an empty result (previously Poly::inv was called with
// length 0 and never terminated).
// Uses the scratch buffers Poly::h (input copy) and Poly::g (output); Poly::inv
// additionally overwrites Poly::f, r, lim and l.
// a[0] is expected to be nonzero modulo P, since it is what gets inverted.
poly operator~(const poly &a) {
    const int sz = a.size();
    if (sz == 0) return poly();
    using namespace Poly;
    for (int i = 0; i < sz; i++) h[i] = a[i];
    inv(h, g, sz);
    return poly(g, g + sz);
}
// Quotient of the polynomial division a / b; the remainder computed by
// Poly::div is discarded.
poly operator/(const poly& a, const poly& b) {
    auto parts = Poly::div(a, b);
    poly quotient;
    quotient.swap(parts.first);  // take the quotient without copying it
    return quotient;
}
// Remainder of the polynomial division a / b; the quotient computed by
// Poly::div is discarded.
poly operator%(const poly& a, const poly& b) {
    auto parts = Poly::div(a, b);
    poly remainder;
    remainder.swap(parts.second);  // take the remainder without copying it
    return remainder;
}

本文作者:ZnPdCo

本文链接: https://znpdco.fun/2025/02/16/poly/

本页面的全部内容在 CC BY-SA 4.0 和 SATA 协议之条款下提供,附加条款亦可能应用

评论