2 条题解

  • 5
    @ 2023-1-14 10:17:00
    #include<bits/stdc++.h>
    #define ll long long
    #define pb push_back
    #define mp make_pair
    #define pii pair<int,int>
    #define x first
    #define y second
    #define WT int TT=read();while(TT--) 
    #define NO puts("NO");
    #define YES puts("YES");
    using namespace std;
    
    inline int read()
    {
    	char c=getchar();int x=0;bool f=0;
    	for(;!isdigit(c);c=getchar())f^=!(c^45);
    	for(;isdigit(c);c=getchar())x=(x<<1)+(x<<3)+(c^48);
    	if(f)x=-x;return x;
    }
    void ckmax(int &a,int b){a=(a>b?a:b);}
    void ckmin(int &a,int b){a=(a<b?a:b);}
    
    const int Mod=998244353;
    const int M=4e5+10;
    int n,a[M],p[M],rt[M],f[M],m,ans,c[M],inv[M];
    vector<int>e[M];
    
    int poww(int a,int b=Mod-2)
    {
    	int res=1;
    	while(b)
    	{
    		if (b&1)res=1ll*res*a%Mod;
    		a=1ll*a*a%Mod,b>>=1;
    	}return res;
    }
    
    struct tree
    {
    	signed ls,rs,v,s,tag1,tag2;
    }t[M*100];int cnt;
    
    int F(int x){return (x>=Mod)?x-Mod:x;}
    
    int newnode(int L,int R){t[++cnt]=(tree){0,0,0,p[R]-p[L-1],0,1};return cnt;}
    
    void mul(int k,int v){t[k].tag2=1ll*t[k].tag2*v%Mod,t[k].tag1=1ll*t[k].tag1*v%Mod,t[k].v=1ll*t[k].v*v%Mod;}
    void add(int k,int v){t[k].tag1=F(t[k].tag1+v),t[k].v=(t[k].v+1ll*v*t[k].s)%Mod;}
    
    void pushup(int k){t[k].v=F(t[t[k].ls].v+t[t[k].rs].v);}
    
    void pushdown(int k,int L,int R)
    {
    	if (t[k].tag1==0&&t[k].tag2==1)return;
    	int Mid=(L+R)>>1; 
    	if (!t[k].ls)t[k].ls=newnode(L,Mid);
    	if (!t[k].rs)t[k].rs=newnode(Mid+1,R);
    	if (t[k].tag2!=1)mul(t[k].ls,t[k].tag2),mul(t[k].rs,t[k].tag2),t[k].tag2=1;
    	if (t[k].tag1!=0)add(t[k].ls,t[k].tag1),add(t[k].rs,t[k].tag1),t[k].tag1=0;
    }
    
    void update(int &k,int L,int R,int l,int r,int v1,int v2)
    {
    	if (l>r)return;
    	if (L>r||R<l)return;
    	if (!k)k=newnode(L,R);
    	if (L>=l&&R<=r)
    	{
    		if (v2==0)return k=0,void();
    		if (v1)add(k,v1);
    		if (v2!=1)mul(k,v2);
    		return;
    	}
    	pushdown(k,L,R);
    	int Mid=(L+R)>>1;
    	update(t[k].ls,L,Mid,l,r,v1,v2);
    	update(t[k].rs,Mid+1,R,l,r,v1,v2);
    	pushup(k);
    }
    
    int query(int k,int L,int R,int pos)
    {
    	if (!k)return 0;
    	if (L==R)return t[k].v;
    	pushdown(k,L,R);
    	int Mid=(L+R)>>1;
    	if (pos<=Mid)return query(t[k].ls,L,Mid,pos);
    	else return query(t[k].rs,Mid+1,R,pos);
    }
    
    int merge(int u,int v,int t1,int t2,int t3,int L,int R)
    {
    	if (!u||!v)
    	{
    		int x=u+v==0?newnode(L,R):u+v;
    		if (u==x)mul(u,t1);
    		if (v==x)mul(v,t2);
    		add(x,t3);
    		return x;
    	}
    	mul(u,t1),mul(v,t2),add(u,t3);
    	int Mid=(L+R)>>1;
    	t[u].v=F(t[u].v+t[v].v);
    	if (L==R)return u;
    	t[u].ls=merge(t[u].ls,t[v].ls,t[u].tag2,t[v].tag2,F(t[u].tag1+t[v].tag1),L,Mid);
    	t[u].rs=merge(t[u].rs,t[v].rs,t[u].tag2,t[v].tag2,F(t[u].tag1+t[v].tag1),Mid+1,R);
    	t[u].tag2=t[v].tag2=1,t[u].tag1=t[v].tag1=0;
    	return u;
    }
    
    void split(int x,int &y,int L,int R,int pos)
    {
    	if (pos==R||!x)return;
    	if (!y)y=newnode(L,R);
    	pushdown(x,L,R);
    	int Mid=(L+R)>>1;
    	if (pos<=Mid)split(t[x].ls,t[y].ls,L,Mid,pos),swap(t[x].rs,t[y].rs);
    	else split(t[x].rs,t[y].rs,Mid+1,R,pos);
    	pushup(x),pushup(y);
    }
    
    void dfs(int u,int fa=0)
    {
    	rt[u]=0,f[u]=a[u];
    	int tt=0;
    	vector<int>V;
    	for (auto v:e[u])
    		if (v!=fa)
    			dfs(v,u),ckmin(f[u],f[v]),tt++,V.pb(f[v]);
    	for (int i=1;i<=tt;i++)c[i]=V[i-1];
    	sort(c+1,c+1+tt);
    	int tmp1=0,tmp2=0;
    	for (auto v:e[u])
    		if (v!=fa)
    		{
    			tmp2=0;split(rt[v],tmp2,1,m,f[v]);
    			rt[u]=merge(rt[u],rt[v],1,1,0,1,m);
    			tmp1=merge(tmp1,tmp2,1,1,0,1,m);
    		}
    	c[++tt]=m;
    	for (int i=1;i<tt;i++)
    		if (c[i]!=c[i+1])
    			update(rt[u],1,m,c[i]+1,c[i+1],0,inv[i+1]),
    			update(tmp1,1,m,c[i]+1,c[i+1],0,inv[i]);
    	rt[u]=merge(rt[u],tmp1,1,1,0,1,m);
    	update(rt[u],1,m,1,a[u],1,1);
    	update(rt[u],1,m,a[u]+1,m,0,0);
    //	cout<<u<<' '<<t[rt[u]].v<<'\n';
    	ans=F(ans+t[rt[u]].v);
    }
    
    void solve()
    {
    	n=read();int rt=read();
    	ans=0;cnt=0;
    	for (int i=1;i<=n;i++)a[i]=p[i]=read(),inv[i]=poww(i),e[i].clear();
    	for (int i=1;i<n;i++)
    	{
    		int u=read(),v=read();
    		e[u].pb(v),e[v].pb(u);
    	}
    	sort(p+1,p+1+n);
    	m=unique(p+1,p+1+n)-p-1;
    	for (int i=1;i<=n;i++)a[i]=lower_bound(p+1,p+1+m,a[i])-p;
    //	puts("-----");
    	dfs(rt);
    	cout<<ans<<'\n';
    }
    
    signed main()
    {
    	WT solve();
    	return 0;
    }
    
    • @ 2023-6-11 10:15:45

      @大佬,这直接CE

  • -12
    @ 2022-8-3 15:16:01

    拒绝

  • 1

信息

ID
258
时间
5000ms
内存
1024MiB
难度
9
标签
递交数
19
已通过
4
上传者