1 条题解
-
1
#include <bits/stdc++.h> using namespace std; const int MAXN=5e5+5; int n,d,m; int cost[MAXN]; bool vis[MAXN]; vector<int> q[MAXN]; int f[MAXN][31]; int g[MAXN][31]; void dfs(int x,int pr) { if(vis[x]) f[x][0]=g[x][0]=cost[x]; else f[x][0]=0,g[x][0]=0; for(int i=1;i<=d;i++) f[x][i]=cost[x]; for(int i=0;i<q[x].size();i++) { int nx=q[x][i]; if(nx==pr)continue; dfs(nx,x); } for(int i=0;i<q[x].size();i++) { int nx=q[x][i]; if(nx==pr)continue; for(int j=d;j>=0;j--){ f[x][j]=min(min(f[x][j]+g[nx][j],g[x][j+1]+f[nx][j+1]),f[x][j]+f[nx][j+1]); } for(int j=d;j>=0;j--) f[x][j]=min(f[x][j+1],f[x][j]); g[x][0]=f[x][0]; for(int j=1;j<=d+1;j++) g[x][j]=g[x][j]+g[nx][j-1]; for(int j=1;j<=d+1;j++) g[x][j]=min(g[x][j],g[x][j-1]); } for(int j=d;j>=0;j--) f[x][j]=min(f[x][j+1],f[x][j]); for(int j=1;j<=d+1;j++) g[x][j]=min(g[x][j],g[x][j-1]); } int main() { scanf("%d%d",&n,&d); for(int i=1;i<=n;i++) scanf("%d",&cost[i]); scanf("%d",&m); for(int i=1;i<=m;i++){ int x;scanf("%d",&x),vis[x]=1; } for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); q[x].push_back(y); q[y].push_back(x); } memset(f,0x3f,sizeof(f)); dfs(1,1); int ans=1e9; for(int i=0;i<=d;i++) ans=min(f[1][i],ans); printf("%d\n",ans); return 0; }
- 1
信息
- ID
- 2203
- 时间
- 2000ms
- 内存
- 250MiB
- 难度
- 6
- 标签
- 递交数
- 3
- 已通过
- 3
- 上传者