SPOJ COT3 Combat on a tree(Trie树、线段树的合并)
Alice and Bob are playing a game on a tree of n nodes rooted at node 1. Each node is initially either black or white.
They take turns performing the following operation:
Choose a white node v in the current tree;
Color all white nodes on Path(1, v) black.
The player who makes the last move wins.
Alice moves first. Help her determine whether she can win when both players play optimally.
The first line of input contains an integer n, the number of nodes in the tree (1 <= n <= 100000).
The second line contains n integers c1, c2, ..., cn (0 <= ci <= 1).
ci = 0 means the i-th node is initially white, and ci = 1 means it is initially black.
Each of the next n-1 lines describes an edge of the tree: two integers u and v, meaning there is an edge connecting u and v.
If Alice cannot win, print -1.
Otherwise, determine all the nodes she can choose on her first move in order to win, and print them in ascending order.
给定一棵n个点的有根树,每个点是黑的或者白的。 两个人在树上博弈,轮流进行以下操作: 选择一个当前为白色的点u,把u到根路径上的所有点涂黑。不能操作者输,判断两人都用最优策略进行游戏时的胜负情况,并输出第一个人第一步所有可行的决策。
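设 $dp[u]$ 为只考虑子树 $u$ 时的 SG 值,$g[u][w]$ 为在子树 $u$ 中选择白点 $w$ 之后剩余局面的 SG 值,则 $dp[u]=\operatorname{mex}\{\,g[u][w] \mid w \text{ 是子树 } u \text{ 中的白点}\,\}$。下文 $\oplus$ 表示异或(Nim 和)。
$$g[u][w] = g[b][w] \oplus \bigoplus_{v \in \mathrm{son}(u),\ v \neq b} dp[v]$$ 其中 $b=\mathrm{branch}[w]$:$b$ 是 $u$ 的儿子,且 $b$ 是 $w$ 的祖先(或 $w$ 本身)。
$$g[u][w] = g[b][w] \oplus g[u][u] \oplus dp[b]$$ 这一行是上一行的等价写法,其中 $g[u][u]=\bigoplus_{v \in \mathrm{son}(u)} dp[v]$。因此子树 $b$ 中所有的 $g$ 值整体异或同一个数 $g[u][u] \oplus dp[b]$。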
复杂度 $O(n \log^2 n)$。
传递标记 / 询问 mex / 合并树 都是自顶向下进行的,互不矛盾。
因此带异或懒标记的 Trie 合并可以做到 $O(n \log n)$。
代码(C++14,0.55 s):
#include <bits/stdc++.h>
using namespace std;

// Loop i over [0, n). The bound is parenthesized so expressions expand safely.
#define FOR(i, n) for(int i = 0; i < (n); ++i)

const int MAXV = 100010;     // n <= 100000; nodes are 1-indexed
const int MAXE = MAXV << 1;  // every undirected edge is stored as two arcs
const int white = 0;         // c_i == 0 means node i starts white

int max_log;                 // trie depth, set in main to the smallest L with 2^L > n

// Forward-star adjacency lists: head[u] is u's first arc, nxt[] chains arcs, -1 ends a list.
int head[MAXV], color[MAXV], ecnt;
int to[MAXE], nxt[MAXE];
int n;

// Reset the adjacency lists of nodes 1..n (n must already be read).
void initGraph() {
    memset(head + 1, -1, n * sizeof(int));  // byte 0xFF in every byte makes each int -1
    ecnt = 0;
}

// Add the undirected edge u-v as two directed arcs.
void add_edge(int u, int v) {
    to[ecnt] = v; nxt[ecnt] = head[u]; head[u] = ecnt++;
    to[ecnt] = u; nxt[ecnt] = head[v]; head[v] = ecnt++;
}
// Binary trie node; a node at level k covers the k low bits of the stored values.
//   go[b] : child for next bit b (the nil sentinel when empty)
//   size  : number of distinct values stored below (unchanged by xor tags)
//   txor  : pending xor that still has to be applied to every value below
//   mex   : cached mex, maintained by update(); it is NOT adjusted when an
//           xor tag is attached, so it can be stale for tagged subtrees
struct Node {
    Node* go[2];
    int size, txor, mex;
};
- Node statePool[ * MAXV];
- Node *nil, *leaf;
- Node* stk[ * MAXV];
- int ncnt, top;
- Node *new_node() {
- Node* t = top ? stk[--top] : &statePool[ncnt++];
- FOR(i, ) t->go[i] = nil;
- t->size = t->txor = t->mex = ;
- return t;
- }
- void remove(Node *t) {
- stk[top++] = t;
- }
- void initTree() {
- nil = statePool;
- FOR(i, ) nil->go[i] = nil;
- ncnt = ;
- leaf = new_node();
- leaf->mex = leaf->size = ;
- }
- void pushdown(Node *t, int k) {
- if(k > ) {
- if((t->txor >> (k - )) & ) swap(t->go[], t->go[]);
- FOR(i, ) t->go[i]->txor ^= t->txor;
- t->txor = ;
- }
- }
- void update(Node *t, int k) {
- if(k > ) {
- int size = << (k - );
- t->mex = (t->go[]->size < size ? t->go[]->mex : size + t->go[]->mex);
- t->size = t->go[]->size + t->go[]->size;
- }
- }
- Node* merge(Node *a, Node *b, int k) {
- if(a == nil) return b;
- if(b == nil) return a;
- if(a == leaf && b == leaf) return leaf;
- Node *res = new_node();
- pushdown(a, k), pushdown(b, k);
- FOR(i, ) res->go[i] = merge(a->go[i], b->go[i], k - );
- update(res, k);
- remove(a), remove(b);
- return res;
- }
- void insert(Node* &t, int k, int val) {
- if(k == ) t = leaf;
- else {
- if(t == nil) t = new_node();
- pushdown(t, k);
- insert(t->go[(val >> (k - )) & ], k - , val);
- update(t, k);
- }
- }
- Node *root[MAXV];
- int dp[MAXV];
- void dfs(int u, int f) {
- int tmp = ;
- for(int p = head[u]; ~p; p = nxt[p]) {
- int v = to[p];
- if(v != f) dfs(v, u), tmp ^= dp[v];
- }
- if(color[u] == white) insert(root[u], max_log, tmp);
- for(int p = head[u]; ~p; p = nxt[p]) {
- int v = to[p];
- if(v == f) continue;
- root[v]->txor ^= tmp ^ dp[v];
- root[u] = merge(root[u], root[v], max_log);
- }
- dp[u] = root[u]->mex;
- }
- vector<int> ans;
- void dfs_ans(int u, int f, int sg) {
- int tmp = ;
- for(int p = head[u]; ~p; p = nxt[p]) {
- int v = to[p];
- if(v != f) tmp ^= dp[v];
- }
- if(color[u] == white && (sg ^ tmp) == ) ans.push_back(u);
- for(int p = head[u]; ~p; p = nxt[p]) {
- int v = to[p];
- if(v != f) dfs_ans(v, u, sg ^ tmp ^ dp[v]);
- }
- }
- int main() {
- scanf("%d", &n);
- for(int i = ; i <= n; ++i) scanf("%d", &color[i]);
- initGraph();
- for(int i = , u, v; i < n; ++i) {
- scanf("%d%d", &u, &v);
- add_edge(u, v);
- }
- initTree();
- while(( << max_log) <= n) ++max_log;
- for(int i = ; i <= n; ++i) root[i] = nil;
- dfs(, );
- dfs_ans(, , );
- sort(ans.begin(), ans.end());
- for(int x : ans) printf("%d\n", x);
- }
