1 条题解

  • 0
    @ 2025-3-30 17:29:31

    即,从 ANA_N 中有放回地选择 M\le M 个数,问它们异或起来不为 00 的方案数。

    如果令 fi,jf_{i, j} 表示选了 ii 次,异或和为 jj 的方案数,显然 f1,i=[aj=i]f_{1,i}=\sum [a_j=i] 为关于 aa 的桶。此时有 $f_{i,j}=\sum\limits_{k=1}^n f_{i-1,j\oplus a_k}=\sum\limits_{k=0}^V f_{i-1,j\oplus k}\cdot f_{1,k}$,发现把 f1f_1 这个桶在 ff 上做 NN 次 xor-FWT 就可以得到 fnf_n

    但如果直接卷 NN 次是 O(NVlogV)O(N\cdot V\log V) 的,不太美好,但我们看看我们实际上需要做什么:

    1. fif_i 的 FWT。
    2. 求初始桶 f1f_1 的 FWT。
    3. 对位相乘得到 fi+1f_{i+1} 的 FWT。
    4. 通过 FWT 求得原本的 fi+1f_{i+1}

    当这个操作被放在 i=1ni=1\sim n 上依次进行时,我们发现第一步和最后一步会相互抵消,我们只需要求出 f1f_1 的 FWT,FWTi,j(f)FWT_{i, j}(f) 即为 FWT1,j(f)iFWT_{1, j}(f)^i。因为我们要求的是 i,jfi,j\sum\limits_{i, j}f_{i,j} 可以通过等比数列求和求出 FWTj(s)=fi,jFWT_j(s)=\sum f_{i, j}。我们知道 FWT 的变换是线性的,在 FWT 上进行对位运算后可以做 IFWT 得到原数组上的对位运算结果。

    直接做一次逆变换求得 sjs_j 即可。

    #include <bits/stdc++.h>
    int main() {
    #ifdef ONLINE_JUDGE
        std::ios::sync_with_stdio(false);
        std::cin.tie(nullptr), std::cout.tie(nullptr);
    #else
        std::freopen(".in", "r", stdin);
        std::freopen(".out", "w", stdout);
    #endif
        struct mint {
            const int mod = 998244353;
            long long x;
            mint(): x(0ll) {}
            mint(long long x1): x((x1 + mod) % mod) {}
            mint& operator= (const mint q) {
                x = q.x;
                return *this;
            }
            bool operator== (const mint q) const {
                return x == q.x;
            }
            mint operator* (const mint q) const {
                return x * q.x % mod;
            }
            mint& operator*= (const mint q) {
                return *this = *this * q;
            }
            mint operator+ (const mint q) {
                return (x + q.x) % mod;
            }
            mint& operator+= (const mint q) {
                return *this = *this + q;
            }
            mint operator- (const mint q) {
                return (x + mod - q.x) % mod;
            }
            mint qkp(int y) {
                mint res(1ll), x(this->x);
                for (; y; y >>= 1, x *= x)
                    if (y & 1)
                        res *= x;
                return res;
            }
            mint inv(void) {
                return qkp(mod - 2);
            }
        };
        int n, m, k = 16, l = 1 << k;
        std::cin >> m >> n;
        using arr = std::vector<mint>;
        arr a(n + 1), c(l);
        for (int i = 1; i <= n; ++i)
            std::cin >> a[i].x, c[a[i].x] += 1;
        std::vector<arr> mT(2, arr(2)), mI(2, arr(2));
        mT[0][0] = 1ll, mT[0][1] = 1ll, mT[1][0] = 1ll, mT[1][1] = -1ll;
        mI[0][0] = mI[0][1] = mI[1][0] = mint(2ll).inv(), mI[1][1] = mint(-2ll).inv();
        auto calc = [&](arr a, arr &f, std::vector<arr> &w) {
            f = a;
            for (int len = 2; len <= l; len <<= 1)
                for (int i = 0; i < l; i += len)
                    for (int p = i, q = i + len / 2; q < i + len; ++p, ++q)
                        std::tie(f[p], f[q]) = std::make_tuple(f[p] * w[0][0] + f[q] * w[0][1], f[p] * w[1][0] + f[q] * w[1][1]);
            return;
        };
        calc(c, c, mT);
        arr s(l);
        for (int i = 0; i < l; ++i)
            if (c[i] == 1ll)
                s[i] = m;
            else
                s[i] = c[i] * (mint(1ll) - c[i].qkp(m)) * (mint(1ll) - c[i]).inv();
        calc(s, s, mI);
        mint res;
        for (int i = 1; i < l; ++i)
            res += s[i];
        std::cout << res.x << '\n';
        return 0;
    }
    
    • 1

    信息

    ID
    18497
    时间
    2000ms
    内存
    1024MiB
    难度
    6
    标签
    (无)
    递交数
    9
    已通过
    2
    上传者