1 条题解

  • 0
    @ 2024-10-10 10:51:15

    线段树：支持区间乘、区间加、区间求和（结果对 m 取模），使用乘法与加法两种懒标记。

    #include <bits/stdc++.h>
    using namespace std;
    // 64-bit integer type used for node sums and lazy tags. m is an int, so the
    // product of two values already reduced modulo m fits in a long long.
    typedef long long ll;
    // N: capacity of a[]; the tree array is sized N << 2.
    // INF: not referenced anywhere in the visible code.
    const int N = 1e5 + 10, INF = 0x3f3f3f3f;
    // n: number of elements, m: modulus applied to sums and tags,
    // q: number of operations, a[1..n]: initial values (all read in main).
    int n, m, q, a[N];
    
    // One segment-tree node covering the index range [l, r].
    //   x   : pending multiplier not yet pushed to the children (1 = none)
    //   y   : pending addend not yet pushed to the children (0 = none)
    //   sum : sum of the covered elements, reduced modulo m whenever it is
    //         recombined or updated
    // The member order is significant: nodes are brace-initialised
    // positionally as {l, r, x, y, sum}.
    struct Node {
        int l;
        int r;
        ll x;
        ll y;
        ll sum;
    };

    // Backing storage for the tree; node u has children 2u and 2u + 1.
    Node tr[N << 2];
    
    // Combine two partial results: store (l.sum + r.sum) mod m into u.sum.
    // Only u.sum is written; u's range and tags are left as they are.
    void pushup(Node& u, Node& l, Node& r) {
        const long long combined = l.sum + r.sum;
        u.sum = combined % m;
    }
    void pushup(int u) {
        pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
    }
    void pushdown(int u) {
        if (tr[u].x != 1) {
            tr[u << 1].x = (tr[u << 1].x * tr[u].x) % m;
            tr[u << 1].y = (tr[u << 1].y * tr[u].x) % m;
            tr[u << 1].sum = (tr[u << 1].sum * tr[u].x) % m;
            tr[u << 1 | 1].x = (tr[u << 1 | 1].x * tr[u].x) % m;
            tr[u << 1 | 1].y = (tr[u << 1 | 1].y * tr[u].x) % m;
            tr[u << 1 | 1].sum = (tr[u << 1 | 1].sum * tr[u].x) % m;
            tr[u].x = 1;
        }
    
        if (tr[u].y) {
            tr[u << 1].y = (tr[u << 1].y + tr[u].y) % m;
            tr[u << 1].sum = (tr[u << 1].sum + tr[u].y * (tr[u << 1].r - tr[u << 1].l + 1) % m) % m;
            tr[u << 1 | 1].y = (tr[u << 1 | 1].y + tr[u].y) % m;
            tr[u << 1 | 1].sum = (tr[u << 1 | 1].sum + tr[u].y * (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) % m) % m;
            tr[u].y = 0;
        }
    }
    void build(int u, int l, int r) {
        tr[u] = {l, r, 1, 0, 0};
        if (l == r) {
            tr[u].sum = a[l];
            return;
        }
        int mid = l + r >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
    void modify(int u, int l, int r, ll x, ll y) {
        if (l <= tr[u].l && r >= tr[u].r) {
            tr[u].x = (tr[u].x * x) % m;
            tr[u].y = (tr[u].y * x) % m;
            tr[u].sum = (tr[u].sum * x) % m;
            tr[u].y = (tr[u].y + y) % m;
            tr[u].sum = (tr[u].sum + y * (tr[u].r - tr[u].l + 1) % m) % m;
            return;
        }
    
        pushdown(u);
        int mid = tr[u].l + tr[u].r >> 1;
        if (l <= mid) modify(u << 1, l, r, x, y);
        if (r > mid) modify(u << 1 | 1, l, r, x, y);
        pushup(u);
    }
    // Return a Node whose sum field is the sum of the elements of [l, r] that
    // lie inside node u's range.
    //
    // If node u is fully covered, the stored node is returned as is (its sum
    // may be a raw leaf value, so callers should still reduce it modulo m).
    // Otherwise pending tags are pushed down, the intersecting children are
    // queried, and their sums are combined modulo m into a fresh node.
    //
    // The fresh node is fully initialised: its range is the part of [l, r]
    // inside node u and its tags are neutral. Previously it was left
    // default-initialised, so returning it copied indeterminate l/r/x/y values.
    Node query(int u, int l, int r) {
        if (l <= tr[u].l && r >= tr[u].r) return tr[u];

        pushdown(u);
        const int mid = (tr[u].l + tr[u].r) >> 1;

        // Intersection of the query with this node's range.
        const int lo = l > tr[u].l ? l : tr[u].l;
        const int hi = r < tr[u].r ? r : tr[u].r;

        Node res = {lo, hi, 1, 0, 0};
        // Zero-sum placeholders for a side the query does not touch.
        Node left = {0, 0, 0, 0, 0}, right = {0, 0, 0, 0, 0};
        if (l <= mid) left = query(u << 1, l, r);
        if (r > mid) right = query(u << 1 | 1, l, r);
        pushup(res, left, right);
        return res;
    }
    
    int main() {
        scanf("%d%d", &n,&m);
        for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
        build(1, 1, n); int op, x, y, k; 
        scanf("%d",&q);
        while (q--) {
            scanf("%d%d%d", &op, &x, &y);
            if (op != 3) scanf("%d", &k);
            if (op == 1) modify(1, x, y, k, 0);
            if (op == 2) modify(1, x, y, 1, k);
            if (op == 3) printf("%lld\n", query(1, x, y).sum % m);
        }
        return 0;
    }
    
    • 1

    信息

    ID
    1676
    时间
    1000ms
    内存
    512MiB
    难度
    10
    标签
    递交数
    4
    已通过
    1
    上传者