// 「SPOJ10707」Count on a tree II





#include <algorithm>
#include <cstdio>
#include <cmath>
#define rg register
#define file(x) freopen(x".in", "r", stdin), freopen(x".out", "w", stdout)
using namespace std;
// Reads one integer of type T from stdin into s.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Digits are then accumulated until the first non-digit.
// If EOF is reached before any digit, s is left as 0 instead of looping forever.
template < class T > inline void read(T& s) {
    s = 0;
    bool neg = false;
    int c = std::getchar();  // int (not char) so EOF can be told apart from data
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') neg = true;
        c = std::getchar();
    }
    while (c >= '0' && c <= '9') {
        s = s * 10 + (c - '0');  // add the digit value, not the raw code, to avoid overflow near the max
        c = std::getchar();
    }
    if (neg) s = -s;
}
// Array capacities: _ sizes the per-vertex arrays, __ sizes the per-query arrays.
// NOTE(review): identifiers containing "__" are reserved in C++; renaming would
// require touching every user of the constant.
const int _ = 40005, __ = 1e5 + 5;
// Forward-star adjacency storage: head[u] is the index of u's most recently
// added edge (0 = none), tot is the number of stored edges. Each undirected
// edge is stored twice, hence the doubled capacity.
int tot, head[_];
struct Edge { int ver, nxt; } edge[_ << 1];
// Appends the directed edge u -> v to u's forward-star list.
// Edge indices start at 1 (tot is pre-incremented), so 0 in head[]/nxt ends a list.
// Uses plain member assignment rather than the GNU-only compound literal
// `(Edge){ ... }`, which is not standard C++.
inline void Add_edge(int u, int v) {
    ++tot;
    edge[tot].ver = v;
    edge[tot].nxt = head[u];
    head[u] = tot;
}
// Problem input: n vertices, q queries, vertex weights a[1..n]. main copies the
// weights into X[], sorts and deduplicates them (X0 distinct values) for compression.
int n, q, a[_], X0, X[_];
// Tree data built by dfs():
int fir[_];      // fir[u]: tour position where u is entered
int las[_];      // las[u]: tour position where u is left
int vis[_];      // vis[u]: parity of u's occurrences in the current Mo window (1 = counted)
int dep[_];      // dep[u]: depth of u (the sentinel vertex 0 has depth 0)
int fa[17][_];   // fa[i][u]: 2^i-th ancestor of u, 0 once past the root

// Euler tour and Mo's-algorithm blocking:
int len;             // number of tour positions written so far
int ord[_ << 1];     // ord[p]: vertex stored at tour position p
int m;               // block size
int pos[_ << 1];     // pos[p]: block index of tour position p

// Window state and answers:
int ans;             // distinct compressed weights currently counted
int cnt[_];          // cnt[w]: how many counted vertices have compressed weight w
int res[__];         // res[id]: answer for the id-th query

// A query as a tour range [l, r], an extra vertex lca to toggle (0 = none),
// and its original input index id.
struct node {
    int l, r, lca, id;
};
node t[__];
inline bool cmp(const node& x, const node& y)
{ return pos[x.l] != pos[y.l] ? pos[x.l] < pos[y.l] : ((pos[x.l] & 1) ? x.r < y.r : y.r < x.r); } inline void dfs(int u, int f) {
fir[u] = ++len, ord[len] = u;
dep[u] = dep[f] + 1, fa[0][u] = f;
for (rg int i = 1; i <= 16; ++i) fa[i][u] = fa[i - 1][fa[i - 1][u]];
for (rg int i = head[u]; i; i = edge[i].nxt) if (edge[i].ver != f) dfs(edge[i].ver, u);
las[u] = ++len, ord[len] = u;
} inline int LCA(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
for (rg int i = 16; ~i; --i) if (dep[fa[i][x]] >= dep[y]) x = fa[i][x];
if (x == y) return x;
for (rg int i = 16; ~i; --i) if (fa[i][x] != fa[i][y]) x = fa[i][x], y = fa[i][y];
return fa[0][x];
} inline void calc(int x) { vis[x] ? ans -= !--cnt[a[x]] : ans += !cnt[a[x]]++, vis[x] ^= 1; } int main() {
read(n), read(q);
for (rg int i = 1; i <= n; ++i) read(a[i]), X[i] = a[i];
sort(X + 1, X + n + 1);
X0 = unique(X + 1, X + n + 1) - X - 1;
for (rg int i = 1; i <= n; ++i) a[i] = lower_bound(X + 1, X + X0 + 1, a[i]) - X;
for (rg int x, y, i = 1; i < n; ++i) read(x), read(y), Add_edge(x, y), Add_edge(y, x);
dfs(1, 0);
for (rg int x, y, lca, i = 1; i <= q; ++i) {
read(x), read(y), lca = LCA(x, y);
if (fir[x] > fir[y]) swap(x, y);
if (x == lca)
t[i].l = fir[x], t[i].r = fir[y], t[i].lca = 0;
t[i].l = las[x], t[i].r = fir[y], t[i].lca = lca;
t[i].id = i;
m = sqrt(1.0 * len);
for (rg int i = 1; i <= len; ++i) pos[i] = (i - 1) / m + 1;
sort(t + 1, t + q + 1, cmp);
for (rg int l = 1, r = 0, i = 1; i <= q; ++i) {
while (l > t[i].l) calc(ord[--l]);
while (r < t[i].r) calc(ord[++r]);
while (l < t[i].l) calc(ord[l++]);
while (r > t[i].r) calc(ord[r--]);
if (t[i].lca) calc(t[i].lca);
res[t[i].id] = ans;
if (t[i].lca) calc(t[i].lca);
for (rg int i = 1; i <= q; ++i) printf("%d\n", res[i]);
return 0;

