题目链接
https://www.lydsy.com/JudgeOnline/problem.php?id=3697
题解
点分治:统计过当前根节点满足条件的路径数;
统计每个子树中根到每个节点的路径权值和,将权值和为\(x\)的路径的数目记录在\(f[x][0/1],g[x][0/1]\)中,其中\(f\)记录当前子树的信息,\(g\)记录之前遍历过的子树的信息,\(0,1\)分别记录是否存在休息站;
至于如何记录休息站,可以用桶记录当前路径权值的各前缀来判断;
当前子树的贡献即为\(f[0][0]\times g[0][0]+\sum_{i=-mxdep}^{mxdep} f[i][0]\times g[-i][1]+f[i][1]\times g[-i][0]+f[i][1]\times g[-i][1]\),\(mxdep\)为当前子树的最大深度。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
#include<bits/stdc++.h> #define INF 0x3f3f3f3f using namespace std; typedef long long LL; typedef double db; inline int read() { int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-')f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+ch-'0'; ch=getchar(); } return x*f; } const int MAXN=1e5+10; int n,head[MAXN],cnt,sumch[MAXN],d[MAXN],sum,rt,mxch[MAXN],t[MAXN*2],mxdep,dep[MAXN],dis[MAXN]; LL ans,f[MAXN*2][2],g[MAXN*2][2]; bool vis[MAXN]; struct edge { int v,next,val; }e[MAXN*2]; void addedge(int x,int y,int z) { e[++cnt]=(edge){y,head[x],z}; head[x]=cnt; return; } void getroot(int u,int fa) { sumch[u]=1; mxch[u]=0; for(int i=head[u];i;i=e[i].next) { int v=e[i].v; if(v==fa||vis[v])continue; getroot(v,u); sumch[u]+=sumch[v]; mxch[u]=max(mxch[u],sumch[v]); } mxch[u]=max(mxch[u],sum-sumch[u]); if(mxch[u]<mxch[rt])rt=u; return; } void cal(int u,int fa) { mxdep=max(mxdep,dep[u]); if(t[dis[u]])++f[dis[u]][1]; else ++f[dis[u]][0]; ++t[dis[u]]; for(int i=head[u];i;i=e[i].next) { int v=e[i].v; if(v==fa||vis[v])continue; dis[v]=dis[u]+e[i].val; dep[v]=dep[u]+1; cal(v,u); } --t[dis[u]]; return; } void solve(int u) { vis[u]=true; g[n][0]=1; int _max=0; for(int i=head[u];i;i=e[i].next) { int v=e[i].v; if(vis[v])continue; dis[v]=n+e[i].val; dep[v]=1; mxdep=1; cal(v,u); _max=max(_max,mxdep); ans+=(g[n][0]-1)*f[n][0]; for(int j=-mxdep;j<=mxdep;++j) ans+=g[n-j][0]*f[n+j][1]+g[n-j][1]*f[n+j][0]+g[n-j][1]*f[n+j][1]; for(int j=n-mxdep;j<=n+mxdep;++j) { g[j][0]+=f[j][0]; g[j][1]+=f[j][1]; f[j][0]=f[j][1]=0; } } for(int i=n-_max;i<=n+_max;++i) g[i][0]=g[i][1]=0; for(int i=head[u];i;i=e[i].next) { int v=e[i].v; if(vis[v])continue; sum=sumch[v]; rt=0; getroot(v,u); solve(rt); } return; } int main() { n=read(); for(int i=1;i<n;++i) { int x=read(),y=read(),z=read()?1:-1; addedge(x,y,z); addedge(y,x,z); } sum=n;mxch[0]=INF; getroot(1,0); solve(rt); printf("%lld\n",ans); return 0; } |