1 条题解

  • 1
    @ 2022-2-22 19:32:26

    在我的个人博客中阅读

    CF1626F A Random Code Problem

    题目大意

    给出一个长度为 nn 的序列,你需要进行 kk 次操作,第 ii 次操作将会任意选择一个序列元素,将它的值加入到答案后将该数减去其模 ii 的值。求 kk 次操作后你的答案的期望。

    n107, k17, ai<998244353n \le 10^7,~k \le 17,~a_i < 998244353

    分析

    我们发现若第 x1, x2,, xnx_1,~x_2,\dots,~x_n 次操作均在某一值为 vv 的元素上进行,该元素最终值为 vvmodlcmi xiv - v \bmod \mathrm{lcm} _i ~ x_i。由于操作数量不超过 1717,容易发现数组内的每一个元素在结束后都一定不小于 vvmodlcm1i17 iv - v \bmod \mathrm{lcm} _{1 \le i \le 17} ~ i

    我们令 L=lcm1ik iL = \mathrm{lcm}_{1 \le i \le k}~i。则我们可以将每个元素 vv 分为两部分,第一部分为 vvmodLv - v \bmod L,第二部分为 v%Lv \% L。考虑对这两部分分开进行计算。

    第一部分由于在经过任意操作后该数的第一部分均不会发生变化,因此第一部分对答案的贡献即为该数被操作到的期望数,即 $(v - v \bmod L) \times \frac {k \times n^{k-1}} {n^k}$。

    第二部分值小于 LL,由于值域较小而数组元素数量较多,考虑使用桶进行计数。不难计算出 f[i][j]f[i][j] 表示所有使用了 ii 次操作的方案(共 nin^i 种)中的值为 jj 的元素数量之和,通过对每一个状态统计下次操作若操作到该值时该状态对总答案的贡献即可求出答案: $\frac 1 {n^k} \times \sum_i \sum_j j \times f[i][j] \times n^{k-i-1}$。

    k=17k=17LL 可能过大以至于 DP 时间复杂度不可接受。容易发现第 kk 次操作的修改操作对总答案无影响,因此令 L=lcm1i<k iL = \mathrm{lcm}_{1 \le i < k}~i 即可。

    总时间复杂度 O(n+k×lcm1i<k i)O(n + k \times \mathrm{lcm}_{1 \le i < k}~i)

    代码

    View on GitHub

    Code
    /**
     * @file 1626F.cpp
     * @author Macesuted (i@macesuted.moe)
     * @date 2022-02-01
     *
     * @copyright Copyright (c) 2022
     * @brief
     *      My tutorial: https://macesuted.moe/article/cf1626f
     */
    
    #include <bits/stdc++.h>
    using namespace std;
    
    #define MP make_pair
    #define MT make_tuple
    
    namespace io {
    #define SIZE (1 << 20)
    char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, c, qu[55];
    int f, qr;
    inline void flush(void) { return fwrite(obuf, 1, oS - obuf, stdout), oS = obuf, void(); }
    inline char getch(void) {
        return (iS == iT ? (iT = (iS = ibuf) + fread(ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS++)) : *iS++);
    }
    void putch(char x) {
        *oS++ = x;
        if (oS == oT) flush();
        return;
    }
    string getstr(void) {
        string s = "";
        char c = getch();
        while (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF) c = getch();
        while (!(c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF)) s.push_back(c), c = getch();
        return s;
    }
    void putstr(string str, int begin = 0, int end = -1) {
        if (end == -1) end = str.size();
        for (int i = begin; i < end; i++) putch(str[i]);
        return;
    }
    template <typename T>
    T read() {
        T x = 0;
        for (f = 1, c = getch(); c < '0' || c > '9'; c = getch())
            if (c == '-') f = -1;
        for (x = 0; c <= '9' && c >= '0'; c = getch()) x = x * 10 + (c & 15);
        return x * f;
    }
    template <typename T>
    void write(const T& t) {
        T x = t;
        if (!x) putch('0');
        if (x < 0) putch('-'), x = -x;
        while (x) qu[++qr] = x % 10 + '0', x /= 10;
        while (qr) putch(qu[qr--]);
        return;
    }
    struct Flusher_ {
        ~Flusher_() { flush(); }
    } io_flusher_;
    }  // namespace io
    using io::getch;
    using io::getstr;
    using io::putch;
    using io::putstr;
    using io::read;
    using io::write;
    
    bool mem1;
    
    #define maxk 18
    #define maxn 10000005
    #define maxL 720725
    #define mod 998244353
    
    int a[maxn];
    long long f[maxk][maxL];
    
    long long Pow(long long a, long long x) {
        long long ans = 1;
        while (x) {
            if (x & 1) ans = ans * a % mod;
            a = a * a % mod, x >>= 1;
        }
        return ans;
    }
    
    void solve(void) {
        int n = read<int>(), a0 = read<int>(), x = read<int>(), y = read<int>(), k = read<int>(), M = read<int>();
        a[1] = a0;
        for (int i = 2; i <= n; i++) a[i] = (1LL * a[i - 1] * x % M + y) % M;
        int L = 720720;
        long long ans = 0, coeff = k * Pow(n, k - 1) % mod;
        for (int i = 1; i <= n; i++) {
            int rest = a[i] % L;
            f[0][rest]++;
            ans = (ans + 1LL * (a[i] - rest) * coeff) % mod;
        }
        for (int i = 0; i < k; i++) {
            long long t = Pow(n, k - i - 1);
            for (int j = 0; j < L; j++)
                if (f[i][j]) {
                    f[i + 1][j] = (f[i + 1][j] + 1LL * f[i][j] * (n - 1) % mod) % mod;
                    f[i + 1][j - j % (i + 1)] = (f[i + 1][j - j % (i + 1)] + f[i][j]) % mod;
                    ans = (ans + 1LL * j * f[i][j] % mod * t) % mod;
                }
        }
        cout << ans << endl;
        return;
    }
    
    bool mem2;
    
    int main() {
    #ifdef MACESUTED
        cerr << "Memory: " << abs(&mem1 - &mem2) / 1024. / 1024. << "MB" << endl;
    #endif
    
        int _ = 1;
        while (_--) solve();
    
    #ifdef MACESUTED
        cerr << "Time: " << clock() * 1000. / CLOCKS_PER_SEC << "ms" << endl;
    #endif
        return 0;
    }
    
    • 1

    信息

    ID
    7612
    时间
    3000ms
    内存
    512MiB
    难度
    10
    标签
    递交数
    2
    已通过
    2
    上传者