【LOJ3044】【ZJOI2019】Minimax 搜索

题目链接

【LOJ3044】【ZJOI2019】Minimax 搜索

做法

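计算满足 $ W(S) \le i $ 的方案数比计算满足 $ W(S) = i $ 的方案数容易得多。设 $ sum[i] $ 为满足 $ W(S) \le i $ 的叶子集合 $ S $ 的个数,则所求答案为 $ sum[i] - sum[i - 1] $ ,其中 $ sum[0] = 0 $ , $ sum[n] = 2^{cnt_{leaf}} - 1 $ ( $ cnt_{leaf} $ 为叶子个数)。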
先做一次 DFS 求出每个节点的值,并找到决策路径(从根出发,每次走向与当前节点值相同的儿子,直到叶子)。对于每一个阈值 $ i \in [L - 1, R] $ ,设 $ dp[u] $ 表示节点 $ u $ 的子树中有多少种叶子集合的选取方案无法改变根节点的值;用总集合数 $ 2^{cnt_{leaf}} $ 减去根处的方案数,就得到能够改变根节点值的方案数 $ sum[i] $ 。
然后得到一个 $ O(n \times (r - l + 1)) $ 的算法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <bits/stdc++.h>
using namespace std;
// Shared constants and global state for the brute-force solution.
using ll = long long;          // 64-bit type for intermediate products
constexpr int mod = 998244353; // all counts are reported modulo this prime
constexpr int N = 200010;      // capacity for node-indexed arrays

int n, L, R; // number of nodes, and the output range [L, R]

// Adjacency lists stored as linked edge records. Edge ids start at 1, so an
// id of 0 in hed[] / nxt[] marks the end of a list. main stores every
// undirected edge in both directions, hence the 2 * N capacity.
int cnt = 0;
int hed[N];
int to[N + N];
int nxt[N + N];

int f[N];   // minimax value of each subtree (filled by init)
int dep[N]; // depth of each node; the root has depth 1
int sz[N];  // number of leaves in each subtree
int ans[N]; // per-threshold counts, differenced in main
int value;  // minimax value of the root
int ba[N];  // ba[i] = 2^i mod `mod`

inline int Max(int x, int y) { return y < x ? x : y; }
inline int Min(int x, int y) { return y > x ? x : y; }

// Modular addition; both operands are expected to lie in [0, mod).
inline int add(int x, int y) {
    int s = x + y;
    if (s >= mod) s -= mod;
    return s;
}

// Modular subtraction; both operands are expected to lie in [0, mod).
inline int sub(int x, int y) {
    int d = x - y;
    if (d < 0) d += mod;
    return d;
}

// Modular multiplication through a 64-bit intermediate.
inline int mul(int x, int y) { return static_cast<int>(static_cast<ll>(x) * y % mod); }

// Appends the directed edge x -> y to x's adjacency list.
inline void addedge(int x, int y) {
    const int id = ++cnt;
    to[id] = y;
    nxt[id] = hed[x];
    hed[x] = id;
}
// Fills dep[], sz[] (leaf count) and f[] (minimax value) for u's subtree,
// where ff is u's parent. Odd depths take the maximum over their children,
// even depths take the minimum, and a leaf's value is its own index.
void init(int u, int ff) {
    dep[u] = dep[ff] + 1;
    const bool maximise = (dep[u] & 1) != 0;
    f[u] = maximise ? 1 : N;
    bool isLeaf = true;
    for (int e = hed[u]; e != 0; e = nxt[e]) {
        const int v = to[e];
        if (v == ff) continue;
        isLeaf = false;
        init(v, u);
        sz[u] += sz[v];
        f[u] = maximise ? Max(f[u], f[v]) : Min(f[u], f[v]);
    }
    if (isLeaf) {
        sz[u] = 1;
        f[u] = u;
    }
}
// Counts (mod `mod`) the leaf subsets S of u's subtree for which u's minimax
// value stays <= `value` even when every leaf in S is raised by w.
// A leaf contributes [f <= value] (left out of S) plus [f + w <= value]
// (in S and raised by w). An odd-depth (max) node needs every child to stay
// low, so it multiplies the children's counts. An even-depth (min) node needs
// only one child to stay low, so it returns 2^leaves minus the product of the
// children's complements.
int gmx(int u, int ff, int w) {
    const bool isMax = dep[u] & 1;
    bool hasChild = false;
    int prod = 1;
    for (int e = hed[u]; e != 0; e = nxt[e]) {
        const int v = to[e];
        if (v == ff) continue;
        hasChild = true;
        const int below = gmx(v, u, w);
        prod = mul(prod, isMax ? below : sub(ba[sz[v]], below));
    }
    if (!hasChild) return (f[u] <= value) + (f[u] + w <= value);
    return isMax ? prod : sub(ba[sz[u]], prod);
}
// Mirror image of gmx. Counts (mod `mod`) the leaf subsets S of u's subtree
// for which u's minimax value stays >= `value` even when every leaf in S is
// lowered by w.
// A leaf contributes [f >= value] (left out of S) plus [f - w >= value]
// (in S and lowered by w). An even-depth (min) node multiplies its
// children's counts. An odd-depth (max) node returns 2^leaves minus the
// product of the children's complements.
int gmn(int u, int ff, int w) {
    const bool isMax = dep[u] & 1;
    bool hasChild = false;
    int prod = 1;
    for (int e = hed[u]; e != 0; e = nxt[e]) {
        const int v = to[e];
        if (v == ff) continue;
        hasChild = true;
        const int below = gmn(v, u, w);
        prod = mul(prod, isMax ? sub(ba[sz[v]], below) : below);
    }
    if (!hasChild) return (f[u] >= value) + (f[u] - w >= value);
    return isMax ? sub(ba[sz[u]], prod) : prod;
}
// Walks the decision path (the child sharing u's value) and multiplies in a
// count for every other child:
//   - a child on the path recurses into dfs;
//   - other children of odd-depth (max) nodes use gmx;
//   - other children of even-depth (min) nodes use gmn.
// A leaf has no children and yields 1. The product is the number of leaf
// subsets (mod `mod`) that cannot change the root's value with per-leaf
// change budget w.
int dfs(int u, int ff, int w) {
    int ways = 1;
    for (int e = hed[u]; e != 0; e = nxt[e]) {
        const int v = to[e];
        if (v == ff) continue;
        int part;
        if (f[v] == f[u]) part = dfs(v, u, w);
        else if (dep[u] & 1) part = gmx(v, u, w);
        else part = gmn(v, u, w);
        ways = mul(ways, part);
    }
    return ways;
}
// Number of leaf subsets (mod `mod`) that can change the root's value with
// per-leaf budget x. This is 0 for x = 0 and every non-empty subset for x = n.
// Otherwise it is all 2^leaves subsets minus those counted by dfs.
int calc(int x) {
    if (x == 0) return 0;
    const int total = ba[sz[1]];
    return x == n ? sub(total, 1) : sub(total, dfs(1, 0, x));
}
// Brute-force driver. Reads the tree, evaluates calc for every threshold in
// [L - 1, R], and prints the differences (count with cost exactly i).
// Only narrow ranges (R - L <= 50) are handled; wider ranges print nothing.
int main() {
    scanf("%d%d%d", &n, &L, &R);
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        addedge(x, y);
        addedge(y, x);
    }
    if (R - L > 50) return 0;

    ba[0] = 1;
    for (int i = 1; i <= n; ++i) ba[i] = add(ba[i - 1], ba[i - 1]);
    init(1, 0);
    value = f[1];

    for (int i = L - 1; i <= R; ++i) ans[i] = calc(i);
    // Turn "cost <= i" counts into "cost == i" counts, top-down so that
    // ans[i - 1] is still the cumulative value when it is read.
    for (int i = R; i >= L; --i) ans[i] = sub(ans[i], ans[i - 1]);
    for (int i = L; i <= R; ++i) printf("%d ", ans[i]);
    return 0;
}

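注意到随着阈值 $ i $ 增大,每个叶子对应的方案数只会改变一次(由 $ 2 $ 变为 $ 1 $ ),因此可以按 $ i $ 从小到大依次修改,问题转化为动态 DP。
决策路径上的节点本身不需要维护:对每一棵挂在决策路径上的子树(以决策路径节点的非路径儿子为根)分别进行重链剖分并做动态 DP,根节点处无法改变的方案数就是这些子树 DP 值的乘积。
卡常数时可以用一次函数 $ (k, b) $ (即 $ x \mapsto kx + b $ )代替矩阵来表示转移。
本实现的时间复杂度为 $ O(n \log^2 n) $ :每次修改要跳 $ O(\log n) $ 条轻边,每条重链在线段树上的修改和查询都是 $ O(\log n) $ (每跳一条轻边另需一次快速幂求逆元)。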

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include <bits/stdc++.h>
#define pb push_back
#define mp make_pair
#define fst first
#define snd second
using namespace std;
using pii = pair<int, int>;
using ll = long long;
constexpr int mod = 998244353;
constexpr int N = 200010;

// Modular addition; both operands are expected to lie in [0, mod).
inline int add(const int &x, const int &y) {
    const int s = x + y;
    return s >= mod ? s - mod : s;
}

// Modular subtraction; both operands are expected to lie in [0, mod).
inline int sub(const int &x, const int &y) {
    const int d = x - y;
    return d < 0 ? d + mod : d;
}

// Modular multiplication through a 64-bit intermediate.
inline int mul(const int &x, const int &y) { return static_cast<int>(1LL * x * y % mod); }

// Binary exponentiation x^y mod `mod`. With the default exponent mod - 2 it
// yields the modular inverse of x (Fermat's little theorem, mod is prime).
int ksm(int x, int y = mod - 2) {
    int result = 1;
    while (y != 0) {
        if (y & 1) result = mul(result, x);
        y >>= 1;
        x = mul(x, x);
    }
    return result;
}

// Affine map v -> k * v + b (mod `mod`), used as the DP transfer of one node.
struct Pair {
    int k, b;
    Pair(int K = 0, int B = 0) : k(K), b(B) {}
};

// Composition: (x + y)(v) == x(y(v)), i.e. y is applied first.
// Note that this operation is not commutative.
inline Pair operator+(const Pair &x, const Pair &y) {
    const int slope = mul(x.k, y.k);
    const int offset = add(mul(x.k, y.b), x.b);
    return Pair(slope, offset);
}

// A pii (p, z) represents the product p * 0^z: p multiplies the non-zero
// factors and z counts the zero factors, so a zero factor can later be
// removed again without dividing by zero.
inline pii operator+(pii x, const int &y) {
    if (y == 0) ++x.snd;
    else x.fst = mul(x.fst, y);
    return x;
}

// Removes one factor y from a zero-aware product (inverse of operator+).
inline pii operator-(pii x, const int &y) {
    if (y == 0) --x.snd;
    else x.fst = mul(x.fst, ksm(y));
    return x;
}

// Actual value of a zero-aware product.
inline int gval(pii x) { return x.snd == 0 ? x.fst : 0; }

// Global state for the dynamic-DP solution.
int n, lb, rb; // number of nodes, and the output range [lb, rb]
int sum = 1;   // product of dp[v] over the roots v of the subtrees hanging off the decision path
int ans[N];    // ans[i]: count for threshold i, differenced in main

bool leaf[N]; // node has no children
int sz[N];    // subtree size (node count), used to choose heavy children
int a[N];     // 2^(number of leaves in the subtree) mod `mod`
int w[N];     // minimax value of the subtree
int dep[N];   // depth; the root has depth 1
int son[N];   // heavy child (largest subtree), 0 for a leaf
int key[N];   // child whose value equals w[u]: the next decision-path node

int fa[N];    // parent inside a hanging subtree; 0 at a hanging subtree's root
int dfn[N];   // segment-tree position (heavy child numbered first)
int idx = 0;  // last dfn handed out
int dp[N];    // leaf subsets of the subtree that keep it on the same side of w[1]

bool type[N]; // true: dp is a product over children; false: a[u] minus a product of complements
pii pr[N];    // zero-aware product over light children
int top[N];   // top of the node's heavy chain
int ed[N];    // last node of the node's heavy chain
vector<int> e[N];   // adjacency lists
vector<int> chs[N]; // chs[i]: leaves whose dp drops from 2 to 1 at threshold i

struct TR {
struct P {
int lf, rf; Pair w;
P(int Lf = 0, int Rf = 0, Pair W = Pair(1, 0)) :
lf(Lf), rf(Rf), w(W) {}
}; P t[N * 4];
int tot, n, rt;
void pushup(int u) { t[u].w = t[t[u].lf].w + t[t[u].rf].w; }
void build(int &u, int l, int r) {
u = ++tot, t[u].w = Pair(1, 0); if(l >= r) return ;
int mid = (l + r) >> 1;
build(t[u].lf, l, mid), build(t[u].rf, mid + 1, r);
}
void init(int size) { rt = tot = 0, n = size, build(rt, 1, n); }
void mdy(int u, int l, int r, int x, Pair w) {
if(l >= r) { t[u].w = w; return ; }
int mid = (l + r) >> 1;
if(x <= mid) mdy(t[u].lf, l, mid, x, w);
else mdy(t[u].rf, mid + 1, r, x, w);
pushup(u);
}
void Mdy(int x, Pair w) { mdy(1, 1, n, x, w); }
Pair qry(int u, int l, int r, int L, int R) {
if(l == L && r == R) return t[u].w;
int mid = (l + r) >> 1;
if(R <= mid) return qry(t[u].lf, l, mid, L, R);
if(L > mid) return qry(t[u].rf, mid + 1, r, L, R);
return qry(t[u].lf, l, mid, L, mid) +
qry(t[u].rf, mid + 1, r, mid + 1, R);
}
int Qry(int l, int r) { return qry(1, 1, n, l, r).b; }
}; TR tr;

// Fills leaf[], dep[], sz[] (node count), a[] (2^leaves), w[] (minimax
// value), son[] (heavy child) and key[] (child providing w[u]) for u's
// subtree, where ff is u's parent. Odd depths maximise, even depths minimise,
// and a leaf's value is its own index.
void init(int u, int ff) {
    dep[u] = dep[ff] + 1;
    const bool maximise = (dep[u] & 1) != 0;
    leaf[u] = true;
    w[u] = maximise ? 0 : N;
    sz[u] = 1;
    a[u] = 1;
    for (int v : e[u]) {
        if (v == ff) continue;
        leaf[u] = false;
        init(v, u);
        sz[u] += sz[v];
        a[u] = mul(a[u], a[v]);
        if (sz[son[u]] < sz[v]) son[u] = v;
        const bool better = maximise ? (w[v] > w[u]) : (w[v] < w[u]);
        if (better) {
            w[u] = w[v];
            key[u] = v;
        }
    }
    if (leaf[u]) {
        a[u] = 2;
        w[u] = u;
    }
}
// Heavy-light decomposes one subtree hanging off the decision path and
// initialises its dynamic-DP state (dp[], pr[], segment-tree maps).
//   u, ff : current node and its parent
//   tp    : top of the heavy chain u belongs to
//   flag  : type of u. true means dp[u] = prod(dp[c]) over children;
//           false means dp[u] = a[u] - prod(a[c] - dp[c]). It alternates with depth.
//   opt   : true if the path node this subtree hangs from has odd depth. In
//           that case a leaf is "safe" while its value is <= w[1]; otherwise
//           it is safe while its value is >= w[1].
// Each node's segment-tree entry is a Pair(k, b) with dp[u] = k * dp[son[u]] + b,
// so a chain query from top to end yields dp of the chain top.
void dfs2(int u, int ff, int tp, bool flag, bool opt) {
fa[u] = ff, dfn[u] = ++idx;
type[u] = flag, pr[u] = mp(1, 0), top[u] = tp;
if(leaf[u]) {
// A safe leaf starts with dp = 2 (in S or not, both safe). It is filed
// under chs[d], where d is the first threshold at which changing it by d
// would cross w[1]; modify() then lowers its dp by one.
if(opt) {
dp[u] = 2 * (w[u] <= w[1]);
if(w[u] <= w[1]) chs[w[1] - w[u] + 1].pb(u);
}
else {
dp[u] = 2 * (w[u] >= w[1]);
if(w[u] >= w[1]) chs[w[u] - w[1] + 1].pb(u);
}
// Leaf map is the constant dp[u] (k = 0).
tr.Mdy(dfn[u], Pair(0, dp[u]));
}
// Heavy child first, so each heavy chain occupies the contiguous dfn range
// [dfn[top], dfn[ed]] that modify() queries.
if(son[u]) dfs2(son[u], u, tp, flag ^ 1, opt), ed[u] = ed[son[u]];
else ed[u] = u;
// Light children start their own chains. Their contribution is folded into
// the zero-aware product pr[u]: dp[v] for product nodes, a[v] - dp[v] otherwise.
for(auto v : e[u]) if(v != ff && v != son[u]) {
dfs2(v, u, v, flag ^ 1, opt);
if(flag) pr[u] = pr[u] + dp[v];
else pr[u] = pr[u] + sub(a[v], dp[v]);
}
if(son[u]) {
if(flag) {
// dp[u] = P * dp[son]  ->  map (P, 0).
dp[u] = mul(gval(pr[u]), dp[son[u]]);
tr.Mdy(dfn[u], Pair(gval(pr[u]), 0));
}
else {
// dp[u] = a[u] - P * (a[son] - dp[son]) = P * dp[son] + (a[u] - P * a[son]).
dp[u] = mul(gval(pr[u]), sub(a[son[u]], dp[son[u]]));
dp[u] = sub(a[u], dp[u]);
tr.Mdy(dfn[u], Pair(gval(pr[u]), sub(a[u], mul(gval(pr[u]), a[son[u]]))));
}
}
}
// Follows the decision path from u via key[] (deepest path node handled
// first) and decomposes every subtree hanging off it. Each hanging subtree
// root v contributes its initial dp[v] to `sum`. Its fa[v] is reset to 0 so
// that modify() stops climbing there.
void dfs1(int u, int ff) {
    const int pathChild = key[u];
    if (pathChild == 0) return; // u is the leaf that ends the decision path
    dfs1(pathChild, u);
    for (int v : e[u]) {
        if (v == ff || v == pathChild) continue;
        dfs2(v, u, v, 0, dep[u] & 1);
        sum = mul(sum, dp[v]);
        fa[v] = 0;
    }
}
// Applies the threshold change for leaf x: its dp drops by one (from 2 to 1).
// The change is propagated up through the heavy chains of its hanging
// subtree, and the subtree root's factor in `sum` is then replaced.
void modify(int x) {
// The leaf's map becomes the constant dp[x] - 1.
tr.Mdy(dfn[x], Pair(0, sub(dp[x], 1)));
// New dp of the top of x's chain.
int tmp = tr.Qry(dfn[top[x]], dfn[ed[x]]);
x = top[x];
// Climb light edges until reaching a hanging-subtree root (fa == 0).
for(; fa[x];) {
int f = fa[x];
// Swap x's old contribution in f's light-child product for the new one.
if(type[f]) pr[f] = (pr[f] + tmp) - dp[x];
else pr[f] = (pr[f] + sub(a[x], tmp)) - sub(a[x], dp[x]);
dp[x] = tmp, x = f;
// Rebuild f's own map from its updated product (same formulas as dfs2).
if(type[x]) tr.Mdy(dfn[x], Pair(gval(pr[x]), 0));
else tr.Mdy(dfn[x], Pair(gval(pr[x]), sub(a[x], mul(gval(pr[x]), a[son[x]]))));
tmp = tr.Qry(dfn[top[x]], dfn[ed[x]]), x = top[x];
}
// Replace the subtree root's factor in sum: multiply by the new value and
// divide by the old one.
// NOTE(review): ksm(dp[x]) is a modular inverse. If dp[x] were ≡ 0 (mod
// 998244353), `sum` would silently become 0 for good. The true count is
// presumably >= 1, so this only happens when it is a multiple of the
// modulus. Tracking `sum` as a zero-aware pii (like pr[]) would make it exact.
sum = mul(mul(sum, tmp), ksm(dp[x])), dp[x] = tmp;
}
// Dynamic-DP driver. Builds the tree, sweeps the threshold upward applying
// the leaf changes scheduled in chs[], and prints the counts for [lb, rb].
int main() {
    scanf("%d%d%d", &n, &lb, &rb);
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        e[x].pb(y);
        e[y].pb(x);
    }
    init(1, 0);
    tr.init(n);
    dfs1(1, 0);

    // After the updates for threshold i, a[1] - sum subsets can change the root.
    for (int i = 1; i < n; ++i) {
        for (int leafId : chs[i]) modify(leafId);
        ans[i] = sub(a[1], sum);
    }
    ans[n] = a[1] - 1; // every non-empty subset counts at threshold n

    // Turn "cost <= i" counts into "cost == i" counts.
    for (int i = n; i > 0; --i) ans[i] = sub(ans[i], ans[i - 1]);
    for (int i = lb; i <= rb; ++i) printf("%d ", ans[i]);
    return 0;
}