2 条题解
-
5
#include<bits/stdc++.h> #define ll long long #define pb push_back #define mp make_pair #define pii pair<int,int> #define x first #define y second #define WT int TT=read();while(TT--) #define NO puts("NO"); #define YES puts("YES"); using namespace std; inline int read() { char c=getchar();int x=0;bool f=0; for(;!isdigit(c);c=getchar())f^=!(c^45); for(;isdigit(c);c=getchar())x=(x<<1)+(x<<3)+(c^48); if(f)x=-x;return x; } void ckmax(int &a,int b){a=(a>b?a:b);} void ckmin(int &a,int b){a=(a<b?a:b);} const int Mod=998244353; const int M=4e5+10; int n,a[M],p[M],rt[M],f[M],m,ans,c[M],inv[M]; vector<int>e[M]; int poww(int a,int b=Mod-2) { int res=1; while(b) { if (b&1)res=1ll*res*a%Mod; a=1ll*a*a%Mod,b>>=1; }return res; } struct tree { signed ls,rs,v,s,tag1,tag2; }t[M*100];int cnt; int F(int x){return (x>=Mod)?x-Mod:x;} int newnode(int L,int R){t[++cnt]=(tree){0,0,0,p[R]-p[L-1],0,1};return cnt;} void mul(int k,int v){t[k].tag2=1ll*t[k].tag2*v%Mod,t[k].tag1=1ll*t[k].tag1*v%Mod,t[k].v=1ll*t[k].v*v%Mod;} void add(int k,int v){t[k].tag1=F(t[k].tag1+v),t[k].v=(t[k].v+1ll*v*t[k].s)%Mod;} void pushup(int k){t[k].v=F(t[t[k].ls].v+t[t[k].rs].v);} void pushdown(int k,int L,int R) { if (t[k].tag1==0&&t[k].tag2==1)return; int Mid=(L+R)>>1; if (!t[k].ls)t[k].ls=newnode(L,Mid); if (!t[k].rs)t[k].rs=newnode(Mid+1,R); if (t[k].tag2!=1)mul(t[k].ls,t[k].tag2),mul(t[k].rs,t[k].tag2),t[k].tag2=1; if (t[k].tag1!=0)add(t[k].ls,t[k].tag1),add(t[k].rs,t[k].tag1),t[k].tag1=0; } void update(int &k,int L,int R,int l,int r,int v1,int v2) { if (l>r)return; if (L>r||R<l)return; if (!k)k=newnode(L,R); if (L>=l&&R<=r) { if (v2==0)return k=0,void(); if (v1)add(k,v1); if (v2!=1)mul(k,v2); return; } pushdown(k,L,R); int Mid=(L+R)>>1; update(t[k].ls,L,Mid,l,r,v1,v2); update(t[k].rs,Mid+1,R,l,r,v1,v2); pushup(k); } int query(int k,int L,int R,int pos) { if (!k)return 0; if (L==R)return t[k].v; pushdown(k,L,R); int Mid=(L+R)>>1; if (pos<=Mid)return query(t[k].ls,L,Mid,pos); else return query(t[k].rs,Mid+1,R,pos); } int merge(int u,int v,int t1,int t2,int t3,int L,int R) { if (!u||!v) { int x=u+v==0?newnode(L,R):u+v; if (u==x)mul(u,t1); if (v==x)mul(v,t2); add(x,t3); return x; } mul(u,t1),mul(v,t2),add(u,t3); int Mid=(L+R)>>1; t[u].v=F(t[u].v+t[v].v); if (L==R)return u; t[u].ls=merge(t[u].ls,t[v].ls,t[u].tag2,t[v].tag2,F(t[u].tag1+t[v].tag1),L,Mid); t[u].rs=merge(t[u].rs,t[v].rs,t[u].tag2,t[v].tag2,F(t[u].tag1+t[v].tag1),Mid+1,R); t[u].tag2=t[v].tag2=1,t[u].tag1=t[v].tag1=0; return u; } void split(int x,int &y,int L,int R,int pos) { if (pos==R||!x)return; if (!y)y=newnode(L,R); pushdown(x,L,R); int Mid=(L+R)>>1; if (pos<=Mid)split(t[x].ls,t[y].ls,L,Mid,pos),swap(t[x].rs,t[y].rs); else split(t[x].rs,t[y].rs,Mid+1,R,pos); pushup(x),pushup(y); } void dfs(int u,int fa=0) { rt[u]=0,f[u]=a[u]; int tt=0; vector<int>V; for (auto v:e[u]) if (v!=fa) dfs(v,u),ckmin(f[u],f[v]),tt++,V.pb(f[v]); for (int i=1;i<=tt;i++)c[i]=V[i-1]; sort(c+1,c+1+tt); int tmp1=0,tmp2=0; for (auto v:e[u]) if (v!=fa) { tmp2=0;split(rt[v],tmp2,1,m,f[v]); rt[u]=merge(rt[u],rt[v],1,1,0,1,m); tmp1=merge(tmp1,tmp2,1,1,0,1,m); } c[++tt]=m; for (int i=1;i<tt;i++) if (c[i]!=c[i+1]) update(rt[u],1,m,c[i]+1,c[i+1],0,inv[i+1]), update(tmp1,1,m,c[i]+1,c[i+1],0,inv[i]); rt[u]=merge(rt[u],tmp1,1,1,0,1,m); update(rt[u],1,m,1,a[u],1,1); update(rt[u],1,m,a[u]+1,m,0,0); // cout<<u<<' '<<t[rt[u]].v<<'\n'; ans=F(ans+t[rt[u]].v); } void solve() { n=read();int rt=read(); ans=0;cnt=0; for (int i=1;i<=n;i++)a[i]=p[i]=read(),inv[i]=poww(i),e[i].clear(); for (int i=1;i<n;i++) { int u=read(),v=read(); e[u].pb(v),e[v].pb(u); } sort(p+1,p+1+n); m=unique(p+1,p+1+n)-p-1; for (int i=1;i<=n;i++)a[i]=lower_bound(p+1,p+1+m,a[i])-p; // puts("-----"); dfs(rt); cout<<ans<<'\n'; } signed main() { WT solve(); return 0; }
- 1
信息
- ID
- 258
- 时间
- 5000ms
- 内存
- 1024MiB
- 难度
- 9
- 标签
- 递交数
- 19
- 已通过
- 4
- 上传者