1 条题解

  • 0
    @ 2021-11-19 8:23:23

    题意就是已知一个点 aabb 需要满足与 aa 的距离不超过 kkcca,ba,b 的公共子树中的一个数,求方案。

    显然 bb 要么是 aa 的祖先,要么是 aa 的子节点。

    如果是祖先,那么答案就是 min{k,depa1}×(siza1)\min\{ k,dep_a-1 \} \times (siz_a-1)

    如果是子节点,答案就是 sizb1\sum siz_b-1

    考虑下面这个东西怎么统计。

    直接用权值线段树合并,权值就是深度,查询区间和就可以了。


    代码
    #include<bits/stdc++.h>
    #define mid (l+r>>1)
    #define int long long
    using namespace std;
    typedef long long ll;
    const int N=3e5+5;
    int dep[N],siz[N],cnt,ans[N],rt[N],n;
    vector<int> e[N];
    vector<pair<int,int> > que[N];
    struct segtree{int l,r,sum;}tr[N*40];
    void dfs(int x,int fa){
        dep[x]=dep[fa]+1;
        siz[x]=1;
        for(int y:e[x]){
            if(y==fa)continue;
            dfs(y,x);
            siz[x]+=siz[y];
        }
    }
    void add(int &p,int l,int r,int x,int z){
        if(!p)p=++cnt;
        if(l==r){
            tr[p].sum+=z;
            return;
        }
        if(x<=mid)add(tr[p].l,l,mid,x,z);
        else add(tr[p].r,mid+1,r,x,z);
        tr[p].sum=tr[tr[p].l].sum+tr[tr[p].r].sum;
    }
    void he(int &p,int q,int l,int r){
        if(!p||!q){p=p+q;return;}
        if(l==r){
            tr[p].sum+=tr[q].sum;
            return;
        }
        he(tr[p].l,tr[q].l,l,mid);
        he(tr[p].r,tr[q].r,mid+1,r);
        tr[p].sum=tr[tr[p].l].sum+tr[tr[p].r].sum;
    }
    int ask(int p,int l,int r,int L,int R){
        if(!p)return 0;
        if(l>=L&&r<=R)return tr[p].sum;
    	return (L<=mid?ask(tr[p].l,l,mid,L,R):0)+(R>mid?ask(tr[p].r,mid+1,r,L,R):0); 
    }
    void dfs1(int x,int fa){
        for(int y:e[x]){
            if(y==fa)continue;
            dfs1(y,x);
            he(rt[x],rt[y],1,n);
        }
        for(auto i:que[x])ans[i.second]+=ask(rt[x],1,n,dep[x]+1,dep[x]+i.first);
        add(rt[x],1,n,dep[x],siz[x]-1);
    }
    int q;
    signed main() {
        scanf("%lld%lld",&n,&q);
        for(int i=1;i<n;i++){
            int x,y;
            scanf("%lld%lld",&x,&y);
            e[x].push_back(y);
            e[y].push_back(x);
        }
        dfs(1,0);
        for(int i=1;i<=q;i++){
            int x,y;
            scanf("%lld%lld",&x,&y);
            que[x].emplace_back(y,i);
            ans[i]=(siz[x]-1)*min(dep[x]-1,y);
        }
        dfs1(1,0);
        for(int i=1;i<=q;i++)printf("%lld\n",ans[i]);
        return 0;
    }
    
    • 1

    信息

    ID
    1
    时间
    2000ms
    内存
    512MiB
    难度
    10
    标签
    递交数
    2
    已通过
    1
    上传者