$n$ 个点的树，每个点有一个权值；多次询问，每次求一条链上的点共有多少种不同的权值。




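记 $(cu,cv)$ 为当前维护的链，$(qu,qv)$ 为当前询问的链。维护 $inq$ 数组表示“当前点在/不在当前链上”，$vis$ 数组记录每种权值在当前链上出现的次数。每次暴力从 $cu,qu$ 分别爬到它们的 lca，从 $cv,qv$ 分别爬到它们的 lca，沿途翻转经过的点的状态（lca 本身不翻转）。这样维护的点集恰好是 $(qu,qv)$ 这条链去掉 $\operatorname{lca}(qu,qv)$ 之后的部分，所以最后要特判一下 $qu,qv$ 的 lca 是否贡献了一种新的权值。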

#include <bits/stdc++.h>
#define LL long long
using namespace std;
// Inclusive ascending / descending loops; the bound t is evaluated once.
// No `register`: that storage class was removed in C++17. The bound variable
// is named via token pasting because identifiers containing "__" are reserved.
#define rep(i, s, t) for (int i = (s), i##_end = (t); i <= i##_end; ++i)
#define dwn(i, s, t) for (int i = (s), i##_end = (t); i >= i##_end; --i)
// Reads one decimal integer from stdin.
// Non-digit characters before it are skipped, and each '-' among them flips
// the sign. Returns 0 if EOF is reached before any digit; previously that
// case never terminated.
inline int read() {
    int x = 0, f = 1;
    // int, not char: keeps EOF distinguishable and keeps isdigit()'s argument
    // in its defined range (unsigned char values or EOF).
    int ch = getchar();
    for (; ch != EOF && !isdigit(ch); ch = getchar())
        if (ch == '-') f = -f;
    for (; isdigit(ch); ch = getchar()) x = 10 * x + ch - '0';
    return x * f;
}
// NOTE(review): the original bound was lost in extraction. 100010 is an
// assumed value; it must cover both the node count n and the query count m,
// because qs[] is sized by it too. Confirm against the problem limits.
const int maxn = 100010;
int n, m;                 // number of nodes, number of queries
int b[maxn], a[maxn];     // b: sorted weights; a[x]: weight of x, compressed to its rank in b
int blk, bcnt;            // block-size threshold, latest block id
vector<int> G[maxn];      // adjacency lists (nodes are 1..n)
int fa[maxn], dep[maxn];  // parent and depth of each node
namespace splca {
int size[maxn], top[maxn];
void dfs1(int x) {
size[x] = ;
for(auto to : G[x]) {
if(to == fa[x]) continue;
fa[to] = x;
dfs1(to); size[x] += size[to];
void dfs2(int x, int col) {
int k = ; top[x] = col;
for(auto to : G[x])
if(to != fa[x] && size[to] > size[k]) k = to;
if(!k) return;
dfs2(k, col);
for(auto to : G[x])
if(to != fa[x] && to != k) dfs2(to, to);
int lca(int x, int y) {
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
return dep[x] < dep[y] ? x : y;
// Builds the LCA structure with node 1 as the root (sets fa[] and the chains).
void lca_init() {
    splca::dfs1(1);
    splca::dfs2(1, 1);
}
// Global-scope LCA entry point; forwards to the heavy-light implementation.
int lca(int x, int y) {
    const int anc = splca::lca(x, y);
    return anc;
}

// Block-partition state shared by dfs2 and main.
int size[maxn];  // not referenced by the visible code
int bl[maxn];    // bl[x]: block id of node x, used as the Mo sort key
int q[maxn];     // stack of nodes not yet assigned to a block
int top;         // stack pointer into q[]
// Splits the tree into blocks of roughly blk..2*blk nodes for Mo's ordering.
// Also fills dep[] for the children of every visited node.
// Returns how many nodes of x's subtree (x included) are still waiting on
// the stack q[] for a block.
// Stack convention: q[0..top-1] are the live entries, so a push is
// q[top++] and a pop is q[--top] (the same pop main uses for leftovers).
int dfs2(int x) {
    int cur = 0;  // pending nodes contributed by the children visited so far
    for (auto to : G[x]) {
        if (to == fa[x]) continue;
        dep[to] = dep[x] + 1;
        cur += dfs2(to);
        if (cur >= blk) {
            ++bcnt;  // open a new block for the pending nodes
            // Leaves cur at exactly 0 (a `while (cur--)` would leave -1).
            for (; cur > 0; --cur) bl[q[--top]] = bcnt;
        }
    }
    q[top++] = x;
    return cur + 1;
}
int ans[maxn];  // ans[id]: answer to the id-th query, in input order
int vis[maxn];  // vis[c]: number of flagged nodes whose compressed weight is c
int inq[maxn];  // inq[x]: 1 while node x is flagged as part of the current path
struct Ques {
int u, v, fl, fr, id;
bool operator < (const Ques &b) const {
return fl == b.fl ? fr < b.fr : fl < b.fl;
int now;  // number of distinct compressed weights among flagged nodes

// Toggles node x in or out of the flagged set, keeping vis[] and now
// consistent, then steps x up to its parent.
// Whether x is being removed or added must be decided by inq[x]; decrementing
// vis unconditionally corrupts the counters.
// NOTE(review): this name coexists with std::move through
// `using namespace std`; calls with an int lvalue resolve to this
// non-template overload.
void move(int &x) {
    if (inq[x]) {
        if (--vis[a[x]] == 0) now--;
    } else if (++vis[a[x]] == 1) {
        now++;
    }
    inq[x] ^= 1;
    x = fa[x];
}
int main() {
n = read(), m = read();
blk = sqrt(n);
rep(i, , n) b[i] = a[i] = read();
sort(b + , b + n + );
rep(i, , n) a[i] = lower_bound(b+, b+n+, a[i]) - b;
rep(i, , n) {
int u = read(), v = read();
G[u].push_back(v); G[v].push_back(u);
} lca_init(); dep[] = ; dfs2();
while(top) bl[q[--top]] = bcnt;
rep(i, , m) {
int v = read(), u = read();
if(bl[v] > bl[u]) swap(u, v);
qs[i] = (Ques){v, u, bl[v], bl[u], i};
//cout << v << " " << u << " " << bl[v] << " " << bl[u] << endl;
sort(qs + , qs + m + );
int cu = , cv = ;
rep(i, , m) {
int nu = qs[i].u, nv = qs[i].v;
int anc = lca(cu, nu);
while(cu != anc) move(cu);
while(nu != anc) move(nu);
anc = lca(cv, nv);
while(cv != anc) move(cv);
while(nv != anc) move(nv);
cv = qs[i].v, cu = qs[i].u;
anc = lca(cv, cu);
ans[qs[i].id] = now + (!vis[a[anc]]);
rep(i, , m) printf("%d\n",ans[i]);

