2 条题解
-
2
树链剖分（树剖）模板题：单点修改、路径查询。
// Tree path queries via heavy-light decomposition (HLD) over a segment tree.
//
// Input (stdin):
//   n, then n-1 undirected edges "u v" (vertices 1..n), then the n vertex
//   weights w[1..n], then q, then q operations, each one of:
//     CHANGE u t  -- set the weight of vertex u to t
//     QMAX u v    -- print the maximum weight on the path between u and v
//     QSUM u v    -- print the sum of weights on the path between u and v
// The tree is rooted at vertex 1.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <string>
#include <utility>
#include <vector>

namespace {

// Identity for max over a range that misses the query. Never printed: every
// path contains at least one vertex, so a real value is always folded in.
constexpr int kNegInf = std::numeric_limits<int>::min();

// Segment tree over positions 1..n keeping range sums and range maxima,
// supporting point assignment.
class SegTree {
 public:
  // values[1..n] are the initial position values; values[0] is ignored.
  explicit SegTree(const std::vector<int>& values)
      : n_(static_cast<int>(values.size()) - 1),
        sum_(4 * static_cast<std::size_t>(n_ > 0 ? n_ : 0) + 4, 0),
        max_(4 * static_cast<std::size_t>(n_ > 0 ? n_ : 0) + 4, kNegInf) {
    if (n_ > 0) build(1, 1, n_, values);
  }

  // Sets position pos (1-based) to value.
  void update(int pos, int value) { update(1, 1, n_, pos, value); }

  // Sum over positions l..r inclusive; long long so long paths cannot overflow.
  long long query_sum(int l, int r) const { return query_sum(1, 1, n_, l, r); }

  // Maximum over positions l..r inclusive.
  int query_max(int l, int r) const { return query_max(1, 1, n_, l, r); }

 private:
  // Recomputes node o from its two children.
  void pull(int o) {
    sum_[o] = sum_[2 * o] + sum_[2 * o + 1];
    max_[o] = std::max(max_[2 * o], max_[2 * o + 1]);
  }

  void build(int o, int l, int r, const std::vector<int>& values) {
    if (l == r) {
      sum_[o] = values[l];
      max_[o] = values[l];
      return;
    }
    const int mid = l + (r - l) / 2;
    build(2 * o, l, mid, values);
    build(2 * o + 1, mid + 1, r, values);
    pull(o);
  }

  void update(int o, int l, int r, int pos, int value) {
    if (l == r) {
      sum_[o] = value;
      max_[o] = value;
      return;
    }
    const int mid = l + (r - l) / 2;
    if (pos <= mid) {
      update(2 * o, l, mid, pos, value);
    } else {
      update(2 * o + 1, mid + 1, r, pos, value);
    }
    pull(o);
  }

  long long query_sum(int o, int l, int r, int ql, int qr) const {
    if (r < ql || qr < l) return 0;
    if (ql <= l && r <= qr) return sum_[o];
    const int mid = l + (r - l) / 2;
    return query_sum(2 * o, l, mid, ql, qr) +
           query_sum(2 * o + 1, mid + 1, r, ql, qr);
  }

  int query_max(int o, int l, int r, int ql, int qr) const {
    if (r < ql || qr < l) return kNegInf;
    if (ql <= l && r <= qr) return max_[o];
    const int mid = l + (r - l) / 2;
    return std::max(query_max(2 * o, l, mid, ql, qr),
                    query_max(2 * o + 1, mid + 1, r, ql, qr));
  }

  int n_;
  std::vector<long long> sum_;
  std::vector<int> max_;
};

// Heavy-light decomposition of the tree rooted at vertex 1. Every heavy chain
// occupies the contiguous position range dfn[top[x]]..dfn[x].
struct HeavyLight {
  std::vector<int> parent;  // parent[root] == 0
  std::vector<int> depth;   // depth[root] == 1
  std::vector<int> top;     // head vertex of the chain containing x
  std::vector<int> dfn;     // position (1..n) of x in chain-contiguous order

  // adj is 1-based (adj[0] unused) and must describe a tree.
  explicit HeavyLight(const std::vector<std::vector<int>>& adj) {
    const int n = static_cast<int>(adj.size()) - 1;
    parent.assign(n + 1, 0);
    depth.assign(n + 1, 0);
    top.assign(n + 1, 0);
    dfn.assign(n + 1, 0);
    std::vector<int> sub(n + 1, 1);    // subtree sizes
    std::vector<int> heavy(n + 1, 0);  // heavy child, 0 for leaves

    // BFS from the root. Iterative, so path-shaped trees cannot overflow the
    // call stack the way a recursive DFS can.
    std::vector<int> order;
    order.reserve(static_cast<std::size_t>(n));
    order.push_back(1);
    depth[1] = 1;
    for (std::size_t i = 0; i < order.size(); ++i) {
      const int u = order[i];
      for (const int v : adj[u]) {
        if (v == parent[u]) continue;
        parent[v] = u;
        depth[v] = depth[u] + 1;
        order.push_back(v);
      }
    }

    // Reverse BFS order visits every vertex after all of its descendants, so
    // sub[u] is final by the time it is folded into its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
      const int u = *it;
      const int p = parent[u];
      if (p == 0) continue;
      sub[p] += sub[u];
      if (heavy[p] == 0 || sub[u] > sub[heavy[p]]) heavy[p] = u;
    }

    // Each chain head (the root or a light child) walks down its heavy chain,
    // handing out consecutive positions.
    int timer = 0;
    for (const int head : order) {
      if (parent[head] != 0 && heavy[parent[head]] == head) continue;
      for (int x = head; x != 0; x = heavy[x]) {
        top[x] = head;
        dfn[x] = ++timer;
      }
    }
  }

  // Splits the u..v path into O(log n) position ranges [l, r] and calls
  // visit(l, r) once for each.
  template <class Visit>
  void for_each_path_range(int u, int v, Visit visit) const {
    while (top[u] != top[v]) {
      // Lift whichever endpoint sits on the chain with the deeper head.
      if (depth[top[u]] < depth[top[v]]) std::swap(u, v);
      visit(dfn[top[u]], dfn[u]);
      u = parent[top[u]];
    }
    if (dfn[u] > dfn[v]) std::swap(u, v);
    visit(dfn[u], dfn[v]);
  }
};

}  // namespace

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int n = 0;
  if (!(std::cin >> n) || n <= 0) return 0;

  std::vector<std::vector<int>> adj(static_cast<std::size_t>(n) + 1);
  for (int i = 1; i < n; ++i) {
    int u = 0;
    int v = 0;
    std::cin >> u >> v;
    adj[u].push_back(v);
    adj[v].push_back(u);
  }
  std::vector<int> weight(static_cast<std::size_t>(n) + 1, 0);
  for (int i = 1; i <= n; ++i) std::cin >> weight[i];

  const HeavyLight hld(adj);
  // Lay the weights out in decomposition order for the segment tree.
  std::vector<int> by_pos(static_cast<std::size_t>(n) + 1, 0);
  for (int x = 1; x <= n; ++x) by_pos[hld.dfn[x]] = weight[x];
  SegTree seg(by_pos);

  int q = 0;
  std::cin >> q;
  std::string op;
  while (q-- > 0 && std::cin >> op) {
    if (op == "CHANGE") {
      int u = 0;
      int t = 0;
      std::cin >> u >> t;
      seg.update(hld.dfn[u], t);
    } else if (op == "QMAX") {
      int u = 0;
      int v = 0;
      std::cin >> u >> v;
      int best = kNegInf;
      hld.for_each_path_range(u, v, [&](int l, int r) {
        best = std::max(best, seg.query_max(l, r));
      });
      std::cout << best << '\n';
    } else if (op == "QSUM") {
      int u = 0;
      int v = 0;
      std::cin >> u >> v;
      long long total = 0;
      hld.for_each_path_range(u, v, [&](int l, int r) {
        total += seg.query_sum(l, r);
      });
      std::cout << total << '\n';
    }
  }
  return 0;
}
-
1
树链剖分模板题，不多做解释。
// Heavy-light decomposition with an explicitly allocated segment tree,
// answering CHANGE / QMAX / QSUM operations on a vertex-weighted tree rooted
// at vertex 1.
//
// Input (stdin): n, n-1 edges "x y", n weights, m, then m operations:
//   CHANGE x c  -- set the weight of vertex x to c
//   QMAX x y    -- print the maximum weight on the path x..y
//   QSUM x y    -- print the sum of weights on the path x..y
#include <algorithm>
#include <climits>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

namespace {

// Identity for max; never printed because every path has at least one vertex.
const int kNegInf = INT_MIN;

// Adjacency lists stored as singly linked edge chains: last[x] is the newest
// edge leaving x and a[k].pre is the next one. Index 0 is a sentinel that
// terminates every chain.
struct edge {
  int y, pre;
};
std::vector<edge> a;
std::vector<int> last;

void ins(int x, int y) {
  a.push_back(edge{y, last[x]});
  last[x] = static_cast<int>(a.size()) - 1;
}

int n, m;
std::vector<int> w;  // w[x]: weight of vertex x (1-based)

// Per-vertex HLD data: parent, depth, heavy child, subtree size, chain top,
// and position in the heavy-child-first DFS order.
struct tnode {
  int fa, dep, son, siz, top, id;
};
std::vector<tnode> t;  // t[0] is a zeroed sentinel (dep 0, siz 0)

// First pass: parent, depth, subtree size and heavy child of every vertex.
// Recursive, so the call depth equals the tree height.
void dfs1(int x, int fa) {
  t[x] = tnode{fa, t[fa].dep + 1, 0, 1, 0, 0};
  for (int k = last[x]; k; k = a[k].pre) {
    const int y = a[k].y;
    if (y == fa) continue;
    dfs1(y, x);
    t[x].siz += t[y].siz;
    // t[0].siz == 0, so the first child always beats "no heavy child yet".
    if (t[t[x].son].siz < t[y].siz) t[x].son = y;
  }
}

int cnt;
std::vector<int> pos;  // pos[i]: vertex placed at position i

// Second pass: visit the heavy child first so that every heavy chain gets
// consecutive positions; light children start new chains.
void dfs2(int x, int top) {
  t[x].top = top;
  t[x].id = ++cnt;
  pos[cnt] = x;
  if (t[x].son) dfs2(t[x].son, top);
  for (int k = last[x]; k; k = a[k].pre) {
    const int y = a[k].y;
    if (y != t[x].fa && y != t[x].son) dfs2(y, y);
  }
}

// Segment tree node covering positions [l, r]; lc/rc index into tr
// (-1 for leaves). Sums are long long so long paths cannot overflow.
struct trnode {
  int l, r, lc, rc;
  int mx;
  long long sum;
};
std::vector<trnode> tr;  // nodes 1..2n-1, allocated in build order
int trlen;

void pushup(int now, int lc, int rc) {
  tr[now].mx = std::max(tr[lc].mx, tr[rc].mx);
  tr[now].sum = tr[lc].sum + tr[rc].sum;
}

// Allocates the next node for [nl, nr] and recursively builds its subtree.
void build(int nl, int nr) {
  const int now = ++trlen;
  tr[now] = trnode{nl, nr, -1, -1, kNegInf, 0};
  if (nl == nr) {
    tr[now].mx = w[pos[nl]];
    tr[now].sum = w[pos[nl]];
    return;
  }
  const int mid = (nl + nr) >> 1;
  tr[now].lc = trlen + 1;  // the next node allocated is the left child
  build(nl, mid);
  tr[now].rc = trlen + 1;  // ...and after the left subtree, the right child
  build(mid + 1, nr);
  pushup(now, tr[now].lc, tr[now].rc);
}

// Point assignment: position x := c.
void change(int now, int nl, int nr, int x, int c) {
  if (nl == nr) {
    tr[now].mx = c;
    tr[now].sum = c;
    return;
  }
  const int mid = (nl + nr) >> 1;
  const int lc = tr[now].lc;
  const int rc = tr[now].rc;
  if (x <= mid) {
    change(lc, nl, mid, x, c);
  } else {
    change(rc, mid + 1, nr, x, c);
  }
  pushup(now, lc, rc);
}

// Maximum over positions [l, r].
int qmax(int now, int nl, int nr, int l, int r) {
  if (l <= nl && nr <= r) return tr[now].mx;
  const int mid = (nl + nr) >> 1;
  int res = kNegInf;
  if (l <= mid) res = std::max(res, qmax(tr[now].lc, nl, mid, l, r));
  if (mid < r) res = std::max(res, qmax(tr[now].rc, mid + 1, nr, l, r));
  return res;
}

// Sum over positions [l, r].
long long query(int now, int nl, int nr, int l, int r) {
  if (l <= nl && nr <= r) return tr[now].sum;
  const int mid = (nl + nr) >> 1;
  long long res = 0;
  if (l <= mid) res += query(tr[now].lc, nl, mid, l, r);
  if (mid < r) res += query(tr[now].rc, mid + 1, nr, l, r);
  return res;
}

// Maximum vertex weight on the path x..y.
int solve(int x, int y) {
  int ans = kNegInf;
  while (t[x].top != t[y].top) {
    // Always lift the endpoint whose chain head is deeper.
    if (t[t[x].top].dep < t[t[y].top].dep) std::swap(x, y);
    ans = std::max(ans, qmax(1, 1, n, t[t[x].top].id, t[x].id));
    x = t[t[x].top].fa;
  }
  if (t[x].dep > t[y].dep) std::swap(x, y);
  return std::max(ans, qmax(1, 1, n, t[x].id, t[y].id));
}

// Sum of vertex weights on the path x..y.
long long work(int x, int y) {
  long long ans = 0;
  while (t[x].top != t[y].top) {
    if (t[t[x].top].dep < t[t[y].top].dep) std::swap(x, y);
    ans += query(1, 1, n, t[t[x].top].id, t[x].id);
    x = t[t[x].top].fa;
  }
  if (t[x].dep > t[y].dep) std::swap(x, y);
  return ans + query(1, 1, n, t[x].id, t[y].id);
}

}  // namespace

int main() {
  if (std::scanf("%d", &n) != 1 || n <= 0) return 0;

  a.assign(1, edge{0, 0});  // sentinel at index 0
  a.reserve(2 * static_cast<std::size_t>(n));
  last.assign(n + 1, 0);
  for (int i = 1; i < n; i++) {
    int x = 0;
    int y = 0;
    if (std::scanf("%d%d", &x, &y) != 2) return 0;
    ins(x, y);
    ins(y, x);
  }

  w.assign(n + 1, 0);
  for (int i = 1; i <= n; i++) {
    if (std::scanf("%d", &w[i]) != 1) return 0;
  }

  t.assign(n + 1, tnode{0, 0, 0, 0, 0, 0});
  pos.assign(n + 1, 0);
  cnt = 0;
  dfs1(1, 0);
  dfs2(1, 1);

  tr.assign(2 * static_cast<std::size_t>(n), trnode{0, 0, -1, -1, kNegInf, 0});
  trlen = 0;
  build(1, n);

  if (std::scanf("%d", &m) != 1) return 0;
  for (int i = 1; i <= m; i++) {
    char op[16];
    // Field width keeps an over-long token from overflowing op.
    if (std::scanf("%15s", op) != 1) break;
    int x = 0;
    int y = 0;
    if (std::strcmp(op, "CHANGE") == 0) {
      int c = 0;
      if (std::scanf("%d%d", &x, &c) != 2) break;
      change(1, 1, n, t[x].id, c);
    } else if (std::strcmp(op, "QMAX") == 0) {
      if (std::scanf("%d%d", &x, &y) != 2) break;
      std::printf("%d\n", solve(x, y));
    } else if (std::strcmp(op, "QSUM") == 0) {
      if (std::scanf("%d%d", &x, &y) != 2) break;
      std::printf("%lld\n", work(x, y));
    }
  }
  return 0;
}
- 1
信息
- ID
- 1036
- 时间
- 1000ms
- 内存
- 162MiB
- 难度
- 3
- 标签
- 递交数
- 64
- 已通过
- 33
- 上传者