




#include <algorithm>
#include <iterator>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <iomanip>
#include <bitset>
#include <cctype>
#include <cstdio>
#include <string>
#include <vector>
#include <stack>
#include <cmath>
#include <queue>
#include <list>
#include <map>
#include <set>
#include <cassert> /* ⊂_ヽ
  \\ Λ_Λ 来了老弟
    > ⌒ヽ
   /   へ\
   /  / \\
   レ ノ   ヽ_つ
  / /
  / /|
 ( (ヽ
 | |、\
 | 丿 \ ⌒)
 | |  ) /
'ノ )  Lノ */ using namespace std;
#define lson (l , mid , rt << 1)
#define rson (mid + 1 , r , rt << 1 | 1)
#define debug(x) cerr << #x << " = " << x << "\n";
#define pb push_back
#define pq priority_queue
typedef long long ll;
typedef unsigned long long ull;
//typedef __int128 bll;
typedef std::pair<ll, ll> pll;
typedef std::pair<int, int> pii;
typedef std::pair<int, pii> p3;
// priority_queue<int> q;                              // max-heap
// priority_queue<int, vector<int>, greater<int> > q;  // min-heap
#define fi first
#define se second
//#define endl '\n'
#define boost ios::sync_with_stdio(false);cin.tie(0)
#define rep(a, b, c) for(int a = (b); a <= (c); ++ a)
// No trailing ';' so these can be used inside larger expressions.
#define max3(a,b,c) std::max(std::max(a,b), c)
#define min3(a,b,c) std::min(std::min(a,b), c)

// NOTE(review): several numeric literals were lost in this copy of the file;
// the values marked "assumed" are reconstructions -- confirm before relying on them.
const ll oo = 1ll << 62;             // assumed shift; large sentinel that still fits in ll
const ll mos = 0x7FFFFFFF;           // 2147483647 (INT_MAX)
const ll nmos = -0x80000000LL;       // -2147483648 (INT_MIN); a bare 0x80000000 would be +2147483648 in ll
const int inf = 0x3f3f3f3f;
const ll inff = 0x3f3f3f3f3f3f3f3f;  // every byte is 0x3f, so memset(a, 0x3f, ...) on an ll array yields this
const int mod = 1e9 + 7;             // assumed addend
const double esp = 1e-8;             // assumed exponent; floating-point comparison tolerance
const double PI = std::acos(-1.0);
const double PHI = 0.61803399;       // golden-section point (golden ratio conjugate)
const double tPHI = 0.38196601;      // 1 - PHI

// Fast integer reader from stdin.
// Skips every non-digit character (a '-' among them makes the result negative),
// then accumulates decimal digits. Stores the value in x and also returns it.
// At end of input with no digits left, x becomes 0 instead of looping forever.
template<typename T>
inline T read(T &x) {
    x = 0;
    int f = 0;
    int ch = getchar();  // int, not char, so EOF can be told apart from real characters
    while (ch != EOF && (ch < '0' || ch > '9')) f |= (ch == '-'), ch = getchar();
    while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
    return x = f ? -x : x;
}
// cmax/cmin: x = max(x, y) / x = min(x, y) in place.
inline void cmax(int &x, int y) { if (x < y) x = y; }
inline void cmax(ll &x, ll y) { if (x < y) x = y; }
inline void cmin(int &x, int y) { if (x > y) x = y; }
inline void cmin(ll &x, ll y) { if (x > y) x = y; }
/*-----------------------showtime----------------------*/
// NOTE(review): the additive slack after 2e5 was lost in this copy; 10 is assumed.
const int maxn = 2e5 + 10;
int col[maxn];                       // col[i]: colour of vertex i (1-based, read in main)
ll ans = 0, sumcol = 0;              // answer of the current test case; colour total maintained by solve() for gao()
int sz[maxn], wt[maxn], root, curn;  // subtree sizes, heaviest piece per vertex, best centroid so far, component size
int vis[maxn];                       // marks vertices already used as a centroid (set in solve)
// Adjacency lists of the tree; cleared per test case in main.
// NOTE(review): `mp` is used throughout this program but declared nowhere else in this copy.
std::vector<int> mp[maxn];
// Centroid search inside the component currently being decomposed.
// Recomputes sz[] for the part of the tree below u (vertices with vis[v] set
// are treated as removed). It also records in wt[u] the largest piece that
// would remain if u were deleted. The vertex with the smallest such weight seen
// so far is kept in `root`.
// Preconditions set by the callers: curn = number of vertices in the current
// component, and root/wt[root] reset to a sentinel whose weight is inf.
void findRoot(int u, int fa) {
    sz[u] = 1;
    wt[u] = 0;
    for (int v : mp[u]) {
        if (v == fa || vis[v]) continue;
        findRoot(v, u);
        sz[u] += sz[v];
        wt[u] = std::max(wt[u], sz[v]);
    }
    // The part of the component "above" u also becomes a piece when u is removed.
    wt[u] = std::max(wt[u], curn - sz[u]);
    if (wt[u] <= wt[root]) root = u;
}
// State shared between solve() and gao() while one centroid is processed.
ll pp[maxn];          // per-colour accumulator indexed by colour (filled in solve, read in gao)
int youmeiyou[maxn],  // youmeiyou[c]: vertices of colour c on gao()'s current DFS path
    ss;               // the centroid currently being processed (assigned in solve before gao runs)
// gao: DFS over one child subtree of the current centroid `ss`. Vertices with
// vis[] set are skipped. Each vertex visited contributes to the global `ans`.
//   u, fa - current vertex and its parent in this DFS
//   vv    - output list; receives (col[u], sz[u]) for every vertex whose colour
//           has not yet appeared on the DFS path (youmeiyou[col[u]] == 0)
//   cnt   - number of such first-occurrence vertices on the path, u included
//   sumfa - sum of `res` over the ancestors on the path
//   sum   - value handed in by solve(); presumably the total sz[] of the
//           centroid's children processed before this one
// youmeiyou[c] counts vertices of colour c on the current path. It is
// incremented on entry and decremented on the way out.
// NOTE(review): numeric literals were lost in this copy (`ll res = ;`, `== )`,
// `int i=;`; presumably 0) and the closing braces of the for-loop and of the
// function are missing -- restore before compiling.
void gao(int u, int fa, vector<pii>& vv, int cnt, ll sumfa, ll sum) {
ll res = ;
// First occurrence of col[u] on the path: record (colour, sz[u]), bump cnt and
// pick up pp[col[u]] into res. The path counter is bumped unconditionally.
if(youmeiyou[col[u]] == )
vv.pb(pii(col[u], sz[u])), cnt++, res += pp[col[u]]; youmeiyou[col[u]]++;
ans += sumcol - sumfa - res + 1ll * cnt * sum;
// Extra term while the centroid's own colour has not appeared on the path.
if(youmeiyou[col[ss]] == ) ans += sum - pp[col[ss]];
for(int i=; i<mp[u].size(); i++) {
int v = mp[u][i];
if(fa == v || vis[v]) continue;
gao(v, u, vv, cnt, sumfa + res, sum);
// Undo this vertex's path count (presumably after the loop; its closing brace is missing).
youmeiyou[col[u]] --;
} void solve(int u) {
// solve: process the component whose centroid is u.
//   - marks u as removed (vis[u]) and recomputes sz[] with u as the root;
//   - for each neighbour v still present, sets ss = u and runs gao() on v's
//     subtree, then merges the collected (colour, size) pairs from vv into pp[]
//     (a colour with pp[c] == 0 is initialised and added to sumcol) and adds
//     sz[v] to sum;
//   - resets pp[] for the colours queued in needclear;
//   - finds the centroid of every remaining neighbour component via findRoot().
// NOTE(review): this copy is incomplete:
//   - numeric literals are missing (`vis[u] = ;`, `findRoot(u, -);`,
//     `gao(v, -, vv, , , sum)`);
//   - `vv` and `needclear` are not declared anywhere in this file;
//   - nothing is ever pushed onto or popped from needclear;
//   - several closing braces are missing, so whether `sumcol +=` and `sum +=`
//     sit inside the else/inner loop needs confirming;
//   - no recursive solve() call on the newly found centroid is visible after
//     the last findRoot().
vis[u] = ;
findRoot(u, -);
ll sum = ;
sumcol = ;
for(int i=; i<mp[u].size(); i++) {
int v = mp[u][i];
if(vis[v]) continue;
ss = u;
gao(v, -, vv, , , sum); for(int j=; j<vv.size(); j++){
int c = vv[j].fi;
if(pp[c])pp[c] += vv[j].se;
else {
pp[c] = vv[j].se;
sumcol += vv[j].se;
sum += sz[v];
} while(!needclear.empty()) {
pp[needclear.front()] = ;
// Re-centroid each remaining neighbour component before recursing.
for(int i=; i<mp[u].size(); i++) {
int v = mp[u][i];
if(!vis[v]) {
root = ; wt[] = inf; curn = sz[v];
findRoot(v, -);
int main(){
int n, cas = ;
while(~scanf("%d", &n)) {
memset(vis, , sizeof(vis));
for(int i=; i<=n; i++) scanf("%d", &col[i]);
for(int i=; i<=n; i++) mp[i].clear();
for(int i=; i<n; i++) {
int u,v;
scanf("%d%d", &u, &v);
} ans = ;
root = ; wt[] = inf;
curn = n;
findRoot(, -);
printf("Case #%d: %lld\n", ++cas, ans);
return ;
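/* Sample input (the leading vertex count, presumably 6 given 6 colours and 5 edges, appears to be missing from this copy):
6
1 2 3 1 2 3
1 2
1 3
3 4
3 5
4 6
*/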

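/*
 * 虚树 + 树上差分法 (virtual tree + difference on the tree):
 *
 * 对于一种颜色,可以把树分割成许多连通块,同一个连通块内,这种颜色不会产生影响,所以某个点上,某个颜色的影响就是 n - size,size 是包含这个点的连通块的大小。
 *
 * For a fixed colour, the vertices of that colour act as cut points that split
 * the tree into connected components. Inside one component that colour has no
 * effect. So at a given vertex, the contribution of that colour is n - size,
 * where size is the size of the component containing the vertex.
 */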



#include <bits/stdc++.h>

using namespace std;
#define pb push_back
#define fi first
#define se second
#define debug(x) cerr<<#x << " := " << x << endl;
#define bug cerr<<"-----------------------"<<endl;
// Inclusive counting loop. The bounds are parenthesised so that expression
// arguments (e.g. a ternary) keep their meaning.
#define FOR(a, b, c) for(int a = (b); a <= (c); ++ a)
typedef long long ll;
typedef long double ld;
typedef std::pair<int, int> pii;
typedef std::pair<ll, ll> pll;
typedef std::pair<pii, int> PII;

// R(a, b, ...) reads each argument from stdin in order, choosing the scanf
// format by the argument's type; other types fall back to std::cin.
template<class T> void _R(T &x) { std::cin >> x; }
void _R(int &x) { scanf("%d", &x); }
void _R(ll &x) { scanf("%lld", &x); }
void _R(double &x) { scanf("%lf", &x); }
void _R(char &x) { scanf(" %c", &x); }
void _R(char *x) { scanf("%s", x); }
void R() {}
template<class T, class... U> void R(T &head, U &... tail) { _R(head); R(tail...); }

// Fast integer reader from stdin.
// Skips every non-digit character (a '-' among them makes the result negative),
// then accumulates decimal digits. Stores the value in x and also returns it.
// At end of input with no digits left, x becomes 0 instead of looping forever.
template<typename T>
inline T read(T &x) {
    x = 0;
    int f = 0;
    int ch = getchar();  // int, not char, so EOF can be told apart from real characters
    while (ch != EOF && (ch < '0' || ch > '9')) f |= (ch == '-'), ch = getchar();
    while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
    return x = f ? -x : x;
}

const ll inf = 0x3f3f3f3f3f3f3f3f;
// NOTE(review): the addend was lost in this copy; 7 is assumed.
const int mod = 1e9 + 7;
/**********showtime************/
// NOTE(review): numeric literals in this part of the file were lost in this
// copy. The array slack (10) and the number of binary-lifting levels (20) are
// reconstructions.
const int maxn = 2e5 + 10;
const int LOGN = 20;                 // binary-lifting levels; enough for depths below 2^20
int col[maxn], vis[maxn];            // col[i]: colour of vertex i; vis is indexed by colour in main
int sz[maxn], dfn[maxn], dp[maxn], tim;
int fa[maxn][LOGN];                  // fa[u][i]: 2^i-th ancestor of u, 0 past the root (given dfs(root, 0))
long long fen[maxn], ans;            // difference array for the final pushdown pass; the answer
// Adjacency lists of the original tree.
// NOTE(review): `mp` is used by this program but declared nowhere else in this copy.
std::vector<int> mp[maxn];

// DFS from u whose parent is o (pass 0 for the root; vertex 0 is a sentinel
// with dp[0] == 0). Fills the subtree size sz[], entry time dfn[], depth dp[]
// (root has depth 1) and the binary-lifting table fa[][].
// Recursive: stack depth equals the tree height.
void dfs(int u, int o) {
    sz[u] = 1;
    dfn[u] = ++tim;
    fa[u][0] = o;
    dp[u] = dp[o] + 1;
    for (int i = 1; i < LOGN; i++)
        fa[u][i] = fa[fa[u][i - 1]][i - 1];
    for (int v : mp[u]) {
        if (v == o) continue;
        dfs(v, u);
        sz[u] += sz[v];
    }
}

// Lowest common ancestor by binary lifting; requires dfs() to have run.
int lca(int u, int v) {
    if (dp[u] < dp[v]) std::swap(u, v);
    // Lift the deeper vertex to v's depth (dp[0] == 0 stops jumps past the root).
    for (int i = LOGN - 1; i >= 0; i--)
        if (dp[fa[u][i]] >= dp[v]) u = fa[u][i];
    if (u == v) return u;
    // Lift both to just below their LCA.
    for (int i = LOGN - 1; i >= 0; i--)
        if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
    return fa[u][0];
}

// Orders vertices by DFS entry time (used to sort virtual-tree key vertices).
bool cmp(int x, int y) {
    return dfn[x] < dfn[y];
}
// Scratch state used while one colour (curcol) is being processed.
int used[maxn],  // marks vertices queued in build() so the mark can be reset afterwards
    nsz[maxn],   // nsz[u] = n - (sz[u] - cdp[u]), filled by gaoNewSz
    cdp[maxn];   // size cut off below u by curcol-coloured virtual-tree descendants (gaoNewSz)
int curcol;      // colour currently being processed (assigned in main)
int n;           // number of vertices in the current test case (read in main)
// Post-order walk of the virtual tree (xu_mp) from u, whose virtual parent is o.
// For the colour `curcol` it computes:
//   cdp[u] = sum, over the virtual-tree children v of u, of sz[v] when v has
//            colour curcol, otherwise cdp[v] (i.e. the original-tree size cut
//            off below u by the nearest curcol-coloured descendants);
//   nsz[u] = n - (sz[u] - cdp[u]), presumably n minus the size of the
//            curcol-free component whose top is u.
void gaoNewSz(int u, int o) {
    cdp[u] = 0;
    for (int v : xu_mp[u]) {
        if (v == o) continue;
        gaoNewSz(v, u);
        cdp[u] += (col[v] == curcol) ? sz[v] : cdp[v];
    }
    nsz[u] = n - (sz[u] - cdp[u]);
}
void gaoSub(int u, int fa, int val) {
int w = val;
if(col[u] == curcol) {
fen[u] -= val;
else if(col[fa] == curcol || u == )
fen[u] += nsz[u];
w = nsz[u];
} for(int v : xu_mp[u]) {
if(v == fa) continue;
if(col[u] == curcol)gaoSub(v, u, );
else gaoSub(v, u, w);
} //建立虚树
// build: for the key vertices in xu (current colour curcol):
//   - sorts them by DFS entry time (cmp compares dfn[]);
//   - walks them with a stack `st`, using lca() to locate branching points;
//   - runs gaoNewSz and gaoSub;
//   - drains `que` to undo the per-call marks in used[].
// Every vertex touched is marked in used[] and pushed onto `que` once. The guard
// is the used[..] check, presumably `== 0` / `= 1` as in the commented-out line
// below.
// NOTE(review): this copy is incomplete and does not compile:
//   - numeric literals are missing;
//   - `st` and `xu_mp` are not declared in this file;
//   - no edge is ever added to xu_mp, so gaoNewSz/gaoSub would see an empty
//     virtual tree;
//   - the drain loop never calls que.pop(), so it would not terminate as written;
//   - no xu_mp clean-up between colours is visible;
//   - many closing braces are missing.
// The stack-based insert() near the end of this file is a standard construction
// that could replace this loop.
void build(vector <int> & xu) {
sort(xu.begin(), xu.end(), cmp);
queue<int>que; for(int i=; i<xu.size(); i++) {
int u = xu[i];
if(st.size() <= ) st.push(u);
else {
int x = st.top(); st.pop();
int o = lca(x, u);
if(o == x) {
while(!st.empty()) {
int y = st.top(); st.pop(); if(dfn[y] > dfn[o]) {
if(used[y] == ) used[y] = , que.push(y);
x = y;
else if(dfn[y] == dfn[o]) {
if(used[y] == ) used[y] = , que.push(y);
else {
if(used[o] == ) used[o] = , que.push(o);
while(st.size() > ) {
int u = st.top(); st.pop();
int v = st.top();
// if(used[u] == 0) used[u] = 1, que.push(u);
if(used[v] == ) used[v] = , que.push(v);
while(!st.empty())st.pop(); gaoNewSz(, );
gaoSub(, , ); while(!que.empty()) {
int u = que.front();
used[u] = ;
} // Tree difference on paths: the final update pass.
// pushdown: final pass of the tree difference. Walks the original tree (mp)
// from u, with fa as u's parent. `val` is the sum of fen[] over u's proper
// ancestors. For every vertex it adds fen[u] + val + n to the global ans.
void pushdown(int u, int fa, ll val) {
ans += fen[u] + val + n;
val += fen[u];
for(int v : mp[u]) {
if(v == fa) continue;
pushdown(v, u, val);
} int main(){
// main, per test case:
//   - reads n and resets fen/vis/dp for vertices 1..n;
//   - marks the colours that occur (vis[] is reused here, indexed by colour);
//   - reads the edges and runs dfs from the root;
//   - for every colour i present, starts the key-vertex list xu (the root when
//     its colour differs from i), walks node[i] and each vertex's deeper
//     neighbours k whose colour differs, and sets curcol = i;
//   - finally runs pushdown() and prints (ans - n) / <lost literal, presumably 2>.
// NOTE(review): this copy is incomplete:
//   - numeric literals are missing;
//   - the statements reading col[i], filling node[] and mp[], pushing v/k onto
//     xu, clearing xu and calling build(xu) are not visible;
//   - node and xu are not declared in this file;
//   - most closing braces are missing.
int cas = ;
while(~scanf("%d", &n)){
ans = ;tim = ;
for(int i=; i<=n; i++){
fen[i] = ;
vis[i] = ;
dp[i] = ;
for(int i=; i<=n; i++) {
vis[col[i]] = ;
for(int i=; i<n; i++) {
int u,v;
read(u); read(v);
} dfs(, ); for(int i=; i<maxn; i++) {
if(vis[i]) {
if(col[] != i) xu.pb();
for(int v : node[i]) {
for(int k : mp[v]) {
if(col[k] != i && dp[k] > dp[v])
curcol = i;
pushdown(, , );
printf("Case #%d: %lld\n", ++cas, (ans - n )/ );
return ;


// Stack-based virtual-tree insertion.
// s[1..top] holds the chain of vertices whose virtual-tree edges are not yet
// final. Vertices must be inserted in increasing dfn order. add_edge(p, c) is
// called once a vertex c is popped, with p the vertex that becomes its parent.
// NOTE(review): the literal in the first test was lost in this copy. `top == 1`
// assumes the caller seeds the stack with the tree root (s[top = 1] = root)
// before the first insert -- confirm against the caller. The vertices left on
// the stack after the last insert still need to be linked by the caller.
void insert(int x) {
    if (top == 1) { s[++top] = x; return; }
    int lca = LCA(x, s[top]);
    if (lca == s[top]) { s[++top] = x; return; }
    // Pop every chain vertex that lies below lca, linking it to the vertex beneath it.
    while (top > 1 && dfn[s[top - 1]] >= dfn[lca]) add_edge(s[top - 1], s[top]), top--;
    // lca falls between s[top - 1] and s[top]: splice it in as s[top]'s parent.
    if (lca != s[top]) add_edge(lca, s[top]), s[top] = lca;
    s[++top] = x;
}

