2 条题解

  • 2
    @ 2021-11-19 14:55:41

    树剖（树链剖分）的单点修改、路径查询模板题。

    #include <bits/stdc++.h>
    using namespace std;
    // Capacity of every per-node array; node ids index the arrays directly,
    // so this assumes node ids stay below MAXN (NOTE(review): confirm n <= 30000).
    const int MAXN = 30000 + 5;
    // Large sentinel; -INF is the identity used for (path) maximum queries.
    const int INF = 0x3f3f3f3f;
    // n: number of nodes, q: number of operations.
    int n, q;
    // Adjacency lists of the undirected tree (both directions are stored).
    vector<int> e[MAXN];
    int w[MAXN]; // weight of each node (1-based)
    //---------Heavy-light decomposition: basic data---------
    // Per node: parent, depth, subtree size, heavy child (0 = no child)
    int fa[MAXN], dep[MAXN], siz[MAXN], hson[MAXN];
    // First DFS of the heavy-light decomposition.
    // Computes siz[] and hson[] for every node in u's subtree and assigns
    // fa[]/dep[] to u's children. `fat` is u's parent (0 for the root); the
    // caller must already have set dep[u].
    void dfs_build(int u, int fat)
    {
        // No heavy child yet: hson 0 is a sentinel whose size is kept at 0,
        // so the first real child always wins the comparison below.
        hson[u] = 0;
        siz[0] = 0;
        siz[u] = 1;
        for (int v : e[u])
        {
            if (v != fat)
            {
                fa[v] = u;
                dep[v] = dep[u] + 1;
                dfs_build(v, u);
                siz[u] += siz[v];
                // Strictly larger replaces, so ties keep the earlier child.
                if (siz[hson[u]] < siz[v])
                    hson[u] = v;
            }
        }
    }
    // Per node: top of its heavy chain (top), position in the heavy-child-first
    // DFS order (dfn); rnk maps a DFS position back to its node id.
    // tot is the running DFS-position counter used while assigning dfn.
    int tot, top[MAXN], dfn[MAXN], rnk[MAXN];
    // Second DFS of the heavy-light decomposition.
    // Assigns dfn[]/rnk[] so that each heavy chain occupies a contiguous range
    // of positions, and sets top[] for u's children. The caller must set top[u]
    // before the call. `parent` is u's parent (0 for the root); it is not named
    // `fa` so that it does not shadow the global parent array.
    void dfs_div(int u, int parent)
    {
        dfn[u] = ++tot;
        rnk[tot] = u;
        // hson[u] == 0 means u has no children (dfs_build picks a heavy child
        // whenever at least one child exists), so there is nothing to visit.
        if (!hson[u])
            return;
        // Heavy child first: it continues u's chain at position dfn[u] + 1.
        top[hson[u]] = top[u];
        dfs_div(hson[u], u);
        // Every light child starts a new chain headed by itself.
        for (int v : e[u])
        {
            if (v == parent || v == hson[u])
                continue;
            top[v] = v;
            dfs_div(v, u);
        }
    }
    
    //---------Segment tree over DFS positions---------
    // Supports point assignment, range sum and range max over positions 1..n.
    // Leaves are initialized from w[rnk[p]]; node o has children 2o and 2o+1.
    struct SegTree
    {
        int sum[MAXN * 4], maxx[MAXN * 4];

        // Recompute node o from its two children.
        void pull(int o)
        {
            const int lc = o * 2;
            const int rc = o * 2 + 1;
            sum[o] = sum[lc] + sum[rc];
            maxx[o] = max(maxx[lc], maxx[rc]);
        }

        // Build the subtree of node o covering positions [l, r].
        void build(int o, int l, int r)
        {
            if (l != r)
            {
                const int mid = (l + r) >> 1;
                build(o * 2, l, mid);
                build(o * 2 + 1, mid + 1, r);
                pull(o);
            }
            else
            {
                sum[o] = maxx[o] = w[rnk[l]];
            }
        }

        // Maximum over [ql, qr] intersected with [l, r]; -INF if they are disjoint.
        int query_max(int o, int l, int r, int ql, int qr)
        {
            const bool disjoint = qr < l || ql > r;
            if (disjoint)
                return -INF;
            const bool covered = ql <= l && r <= qr;
            if (covered)
                return maxx[o];
            const int mid = (l + r) >> 1;
            const int fromLeft = query_max(o * 2, l, mid, ql, qr);
            const int fromRight = query_max(o * 2 + 1, mid + 1, r, ql, qr);
            return max(fromLeft, fromRight);
        }

        // Sum over [ql, qr] intersected with [l, r]; 0 if they are disjoint.
        int query_sum(int o, int l, int r, int ql, int qr)
        {
            const bool disjoint = qr < l || ql > r;
            if (disjoint)
                return 0;
            const bool covered = ql <= l && r <= qr;
            if (covered)
                return sum[o];
            const int mid = (l + r) >> 1;
            const int fromLeft = query_sum(o * 2, l, mid, ql, qr);
            const int fromRight = query_sum(o * 2 + 1, mid + 1, r, ql, qr);
            return fromLeft + fromRight;
        }

        // Assign value t to position x, then refresh the ancestors of its leaf.
        void update(int o, int l, int r, int x, int t)
        {
            if (l == r)
            {
                sum[o] = maxx[o] = t;
                return;
            }
            const int mid = (l + r) >> 1;
            if (x > mid)
                update(o * 2 + 1, mid + 1, r, x, t);
            else
                update(o * 2, l, mid, x, t);
            pull(o);
        }
    } st;
    //---------HLD: path maximum---------
    // Maximum node weight on the tree path between x and y (both inclusive).
    int query_max(int x, int y)
    {
        int best = -INF;
        while (top[x] != top[y])
        {
            // Lift the endpoint whose chain top is deeper (x on ties); its
            // chain segment top..endpoint is contiguous in dfn order.
            int &lift = (dep[top[x]] >= dep[top[y]]) ? x : y;
            best = max(best, st.query_max(1, 1, n, dfn[top[lift]], dfn[lift]));
            lift = fa[top[lift]];
        }
        // Both endpoints now lie on one chain.
        const int lo = min(dfn[x], dfn[y]);
        const int hi = max(dfn[x], dfn[y]);
        best = max(best, st.query_max(1, 1, n, lo, hi));
        return best;
    }
    //---------树剖路径和---------
    int query_sum(int x, int y)
    {
        int ret = 0, fx = top[x], fy = top[y];
        while (fx != fy)
        {
            if (dep[fx] >= dep[fy])
                ret += st.query_sum(1, 1, n, dfn[fx], dfn[x]), x = fa[fx];
            else
                ret += st.query_sum(1, 1, n, dfn[fy], dfn[y]), y = fa[fy];
            fx = top[x];
            fy = top[y];
        }
        if (dfn[x] < dfn[y])
            ret += st.query_sum(1, 1, n, dfn[x], dfn[y]);
        else
            ret += st.query_sum(1, 1, n, dfn[y], dfn[x]);
        return ret;
    }
    // Reads the tree (n nodes, n-1 edges, n weights), roots it at node 1,
    // builds the decomposition and segment tree, then answers q operations:
    //   CHANGE u t  -> set the weight of node u to t
    //   QMAX u v    -> print the maximum weight on the path u..v
    //   QSUM u v    -> print the sum of weights on the path u..v
    int main()
    {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
        cin >> n;
        for (int i = 1; i <= n - 1; i++)
        {
            int u, v;
            cin >> u >> v;
            e[u].push_back(v);
            e[v].push_back(u);
        }
        for (int i = 1; i <= n; i++)
            cin >> w[i];
        // Root at node 1; node 0 acts as the root's (non-existent) parent.
        dep[1] = 1;
        fa[1] = 0;
        dfs_build(1, 0);
        tot = 0;
        top[1] = 1;
        dfs_div(1, 0);
        st.build(1, 1, n);
        cin >> q;
        while (q--)
        {
            string op;
            // Stop on truncated input instead of spinning through the remaining count.
            if (!(cin >> op))
                break;
            // The operations are mutually exclusive: one if/else-if chain.
            if (op == "CHANGE")
            {
                int u, t;
                cin >> u >> t;
                st.update(1, 1, n, dfn[u], t);
            }
            else if (op == "QMAX")
            {
                int u, v;
                cin >> u >> v;
                cout << query_max(u, v) << "\n";
            }
            else if (op == "QSUM")
            {
                int u, v;
                cin >> u >> v;
                cout << query_sum(u, v) << "\n";
            }
        }
        return 0;
    }
    
    • 1
      @ 2023-11-1 21:49:04

      树剖模板题，不多做解释。

      #include<bits/stdc++.h>
      using namespace std;
      // Capacity of per-node arrays; node ids index them directly (assumes ids < N).
      const int N=3e4+5;
      // Large sentinel; -inf is the identity for maximum queries.
      const int inf=0x3f3f3f3f;
      // Linked adjacency list: a[k] is the directed edge x->y, and pre is the
      // index of the previous edge leaving x (0 ends the list). Each tree edge
      // is stored in both directions, hence 2*N slots.
      struct edge{
      	int x,y,pre;
      }a[2*N];
      // last[x]: index of the most recently inserted edge leaving x (0 = none);
      // alen: index of the last used slot in a[].
      int last[N],alen;
      // Append the directed edge x->y at the head of x's adjacency list.
      void ins(int x,int y){
      	alen++;
      	a[alen].x=x;
      	a[alen].y=y;
      	a[alen].pre=last[x];
      	last[x]=alen;
      }
      // n: number of nodes, m: number of operations, w[i]: weight of node i.
      int n,m,w[N];
      // Per-node heavy-light data: parent, depth, heavy child (0 = none),
      // subtree size, head of its heavy chain, position in the heavy-first DFS order.
      struct tnode{
      	int fa,dep,son,siz,top,id;
      }t[N];
      // First DFS: fills fa/dep/siz/son for every node in x's subtree.
      // Reads t[fa].dep, so the root's sentinel parent t[0] must have depth 0;
      // son == 0 means "no heavy child yet" and compares against t[0].siz.
      void dfs1(int x,int fa){
      	t[x].fa=fa;
      	t[x].dep=t[fa].dep+1;
      	t[x].son=0;
      	t[x].siz=1;
      	t[x].top=0;
      	t[x].id=0;
      	for(int k=last[x];k!=0;k=a[k].pre){
      		const int y=a[k].y;
      		if(y!=fa){
      			dfs1(y,x);
      			t[x].siz+=t[y].siz;
      			// Strictly larger replaces, so ties keep the earlier child.
      			if(t[y].siz>t[t[x].son].siz)t[x].son=y;
      		}
      	}
      }
      // cnt: last assigned DFS position; pos[p]: node at DFS position p (inverse of t[].id).
      int cnt,pos[N];
      // Second DFS: numbers nodes heavy-child-first, so each heavy chain occupies
      // consecutive ids. `top` is the head of the chain that x belongs to.
      void dfs2(int x,int top){
      	t[x].top=top;
      	cnt++;
      	t[x].id=cnt;
      	pos[cnt]=x;
      	const int heavy=t[x].son;
      	// The heavy child continues the current chain.
      	if(heavy!=0)dfs2(heavy,top);
      	// Every other child starts a new chain headed by itself.
      	for(int k=last[x];k;k=a[k].pre){
      		const int y=a[k].y;
      		if(y==t[x].fa||y==heavy)continue;
      		dfs2(y,y);
      	}
      }
      // Segment-tree node covering positions [l, r]. lc/rc are the indices of its
      // children in tr[] (-1 for leaves); mx/sum aggregate the weights in [l, r].
      // A tree over n leaves uses 2n-1 nodes, which fits in 2*N slots.
      struct trnode{
      	int l,r,lc,rc;
      	int mx,sum;
      }tr[2*N];
      // Number of tr[] slots used so far; build() allocates nodes in preorder.
      int trlen;
      void pushup(int now,int lc,int rc){
      	tr[now].mx=max(tr[lc].mx,tr[rc].mx);
      	tr[now].sum=tr[lc].sum+tr[rc].sum;
      }
      // Build the subtree for DFS positions [nl, nr]. Nodes are allocated in
      // preorder, so each child's index is the next free slot when it is built.
      void build(int nl,int nr){
      	const int now=++trlen;
      	tr[now]=trnode{nl,nr,-1,-1,-inf,0};
      	if(nl==nr){
      		const int val=w[pos[nl]];
      		tr[now].mx=val;
      		tr[now].sum=val;
      		return;
      	}
      	const int mid=(nl+nr)>>1;
      	tr[now].lc=trlen+1;
      	build(nl,mid);
      	// The right child starts right after the whole left subtree.
      	tr[now].rc=trlen+1;
      	build(mid+1,nr);
      	pushup(now,tr[now].lc,tr[now].rc);
      }
      // Point assignment: set position x to c and refresh the ancestors of its leaf.
      void change(int now,int nl,int nr,int x,int c){
      	if(nl!=nr){
      		const int mid=(nl+nr)>>1;
      		const int lc=tr[now].lc,rc=tr[now].rc;
      		if(x>mid)change(rc,mid+1,nr,x,c);
      		else change(lc,nl,mid,x,c);
      		pushup(now,lc,rc);
      		return;
      	}
      	tr[now].sum=c;
      	tr[now].mx=c;
      }
      // Maximum over positions [l, r] inside node `now`, which covers [nl, nr].
      // Only children that intersect [l, r] are visited; a skipped child
      // contributes -inf, which is also the floor of the result.
      int qmax(int now,int nl,int nr,int l,int r){
      	if(l<=nl&&nr<=r)return tr[now].mx;
      	const int mid=(nl+nr)>>1;
      	const int fromLeft=l<=mid?qmax(tr[now].lc,nl,mid,l,r):-inf;
      	const int fromRight=mid<r?qmax(tr[now].rc,mid+1,nr,l,r):-inf;
      	return max(max(-inf,fromLeft),fromRight);
      }
      // Sum over positions [l, r] inside node `now`, which covers [nl, nr].
      // Only children that intersect [l, r] are visited.
      int query(int now,int nl,int nr,int l,int r){
      	if(l<=nl&&nr<=r)return tr[now].sum;
      	const int mid=(nl+nr)>>1;
      	const int fromLeft=l<=mid?query(tr[now].lc,nl,mid,l,r):0;
      	const int fromRight=mid<r?query(tr[now].rc,mid+1,nr,l,r):0;
      	return fromLeft+fromRight;
      }
      // Maximum node weight on the tree path x..y (answers QMAX).
      int solve(int x,int y){
      	int ans=-inf;
      	while(t[x].top!=t[y].top){
      		// Keep in x the endpoint whose chain head is deeper.
      		if(t[t[y].top].dep>t[t[x].top].dep)swap(x,y);
      		const int head=t[x].top;
      		ans=max(ans,qmax(1,1,n,t[head].id,t[x].id));
      		x=t[head].fa;
      	}
      	// Same chain now: the shallower endpoint has the smaller id.
      	const int lo=t[x].dep<=t[y].dep?x:y;
      	const int hi=lo==x?y:x;
      	return max(ans,qmax(1,1,n,t[lo].id,t[hi].id));
      }
      // Sum of node weights on the tree path x..y (answers QSUM).
      int work(int x,int y){
      	int ans=0;
      	while(t[x].top!=t[y].top){
      		// Keep in x the endpoint whose chain head is deeper.
      		if(t[t[y].top].dep>t[t[x].top].dep)swap(x,y);
      		const int head=t[x].top;
      		ans+=query(1,1,n,t[head].id,t[x].id);
      		x=t[head].fa;
      	}
      	// Same chain now: the shallower endpoint has the smaller id.
      	const int lo=t[x].dep<=t[y].dep?x:y;
      	const int hi=lo==x?y:x;
      	return ans+query(1,1,n,t[lo].id,t[hi].id);
      }
      // Reads the tree (n nodes, n-1 edges, n weights), builds the decomposition
      // rooted at node 1 and the segment tree, then answers m operations
      // ("CHANGE u t", "QMAX u v", "QSUM u v"), dispatched on the second letter.
      int main(){
      	scanf("%d",&n);
      	// Edge slots start at index 2; index 0 terminates adjacency lists.
      	alen=1;memset(last,0,sizeof(last));
      	for(int i=1;i<n;i++){
      		int x,y;scanf("%d%d",&x,&y);
      		ins(x,y),ins(y,x);
      	}
      	for(int i=1;i<=n;i++){
      		scanf("%d",&w[i]);
      	}
      	// Node 0 is the root's sentinel parent: depth 0, size 0.
      	t[0]=tnode{0,0,0,0,0,0};cnt=0;
      	dfs1(1,0),dfs2(1,1);
      	trlen=0;build(1,n);
      	scanf("%d",&m);
      	for(int i=1;i<=m;i++){
      		char op[10];int x,y,c;
      		// The name is stored at op+1, leaving 9 bytes: read at most 8
      		// characters so a malformed token cannot overflow the buffer.
      		if(scanf("%8s",op+1)!=1)break;
      		if(op[2]=='H'){
      			scanf("%d%d",&x,&c);
      			change(1,1,n,t[x].id,c);
      		}
      		else if(op[2]=='M'){
      			scanf("%d%d",&x,&y);
      			printf("%d\n",solve(x,y));
      		}
      		else if(op[2]=='S'){
      			scanf("%d%d",&x,&y);
      			printf("%d\n",work(x,y));
      		}
      	}
      	return 0;
      }
      
      • 1

      信息

      ID
      1036
      时间
      1000ms
      内存
      162MiB
      难度
      3
      标签
      递交数
      64
      已通过
      33
      上传者