1 条题解

  • 1
    @ 2022-8-17 22:19:48
    #include <bits/stdc++.h>
    using namespace std;
    
    const int MAXN=5e5+5;
    
    int n,d,m;
    int cost[MAXN];
    bool vis[MAXN];
    vector<int> q[MAXN];
    
    int f[MAXN][31];
    int g[MAXN][31];
    
    void dfs(int x,int pr)
    {
        if(vis[x])
            f[x][0]=g[x][0]=cost[x];
        else f[x][0]=0,g[x][0]=0;
        for(int i=1;i<=d;i++)
            f[x][i]=cost[x];
        for(int i=0;i<q[x].size();i++)
        {
            int nx=q[x][i];
            if(nx==pr)continue;
            dfs(nx,x);
        }
        for(int i=0;i<q[x].size();i++)
        {
            int nx=q[x][i];
            if(nx==pr)continue;
            for(int j=d;j>=0;j--){
                f[x][j]=min(min(f[x][j]+g[nx][j],g[x][j+1]+f[nx][j+1]),f[x][j]+f[nx][j+1]);
            }
            for(int j=d;j>=0;j--)
                f[x][j]=min(f[x][j+1],f[x][j]);
            g[x][0]=f[x][0];
            for(int j=1;j<=d+1;j++)
                g[x][j]=g[x][j]+g[nx][j-1];
            for(int j=1;j<=d+1;j++)
                g[x][j]=min(g[x][j],g[x][j-1]);
        }
        for(int j=d;j>=0;j--)
            f[x][j]=min(f[x][j+1],f[x][j]);
        for(int j=1;j<=d+1;j++)
            g[x][j]=min(g[x][j],g[x][j-1]);
    }
    
    int main()
    {
        scanf("%d%d",&n,&d);
        for(int i=1;i<=n;i++)
            scanf("%d",&cost[i]);
        scanf("%d",&m);
        for(int i=1;i<=m;i++){
            int x;scanf("%d",&x),vis[x]=1;
        }
        for(int i=1;i<n;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            q[x].push_back(y);
            q[y].push_back(x);
        }
        memset(f,0x3f,sizeof(f));
        dfs(1,1);
        int ans=1e9;
        for(int i=0;i<=d;i++)
            ans=min(f[1][i],ans);
        printf("%d\n",ans);
        return 0;
    }
    
    • 1

    信息

    ID
    2203
    时间
    2000ms
    内存
    250MiB
    难度
    6
    标签
    递交数
    3
    已通过
    3
    上传者