【CodeChef】Count on a Treap

题目链接

【CodeChef】Count on a Treap

做法

将元素按权值排序,再按键值建笛卡尔树,得到的树就是原 $ Treap $ 。
树上两个点的距离等于两个点的深度之和减去它们 $ LCA $ 深度的两倍。
考虑如何计算两个点的 $ LCA $。根据笛卡尔树的性质,任意点对 $ x, y (x \leq y) $ 的 $ LCA $ 为序列 $ [x, y] $ 中的键值最大值所在点的编号。
考虑如何计算一个点的深度。一个点的深度等于从他开始的前缀/后缀键值最大值个数,可以用线段树维护(update 时间复杂度会因在线段树中二分查找前缀/后缀键值最大值个数而多一个 $ \log $ )。
时间复杂度为 $ O(n \log^2 n) $ 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 400010;
int n, lca; ll value;
int tot = 0, rt = 0, lf[N * 2], rf[N * 2], lx[N * 2], rx[N * 2];
ll mx[N * 2];
struct Opt { int opt; ll x, y; }; Opt q[N];
ll p[N]; int lp = 0;

inline ll Max(const ll &x, const ll &y) { return x > y ? x : y; }
void build(int &u, int l, int r) {
u = ++tot; if(l >= r) return ;
int mid = (l + r) >> 1; build(lf[u], l, mid), build(rf[u], mid + 1, r);
}
int glx(int u, ll x) {
if(!lf[u]) return x < mx[u];
return x < mx[rf[u]] ? (lx[u] + glx(rf[u], x)) : glx(lf[u], x);
}
int grx(int u, ll x) {
if(!lf[u]) return x < mx[u];
return x < mx[lf[u]] ? (rx[u] + grx(lf[u], x)) : grx(rf[u], x);
}
void pushup(int u) {
mx[u] = Max(mx[lf[u]], mx[rf[u]]);
lx[u] = glx(lf[u], mx[rf[u]]), rx[u] = grx(rf[u], mx[lf[u]]);
}
void mdy(int u, int l, int r, int x, ll w) {
if(l >= r) { mx[u] = w; return ; }
int mid = (l + r) >> 1;
x <= mid ? mdy(lf[u], l, mid, x, w) : mdy(rf[u], mid + 1, r, x, w);
pushup(u);
}
void gmx(int u, int l, int r, int L, int R) {
int mid = (l + r) >> 1;
if(L == l && r == R) {
if(mx[u] < value) return ;
value = mx[u]; if(l >= r) { lca = l; return ; }
gmx(lf[u], l, mid, L, mid), gmx(rf[u], mid + 1, r, mid + 1, R);
return ;
}
if(R <= mid) gmx(lf[u], l, mid, L, R);
else if(L > mid) gmx(rf[u], mid + 1, r, L, R);
else gmx(lf[u], l, mid, L, mid), gmx(rf[u], mid + 1, r, mid + 1, R);
}
int asklx(int u, int l, int r, int x) {
int ss = 0;
if(x >= r) {
ss = glx(u, value); value = Max(value, mx[u]); return ss;
}
int mid = (l + r) >> 1;
if(mid < x) ss += asklx(rf[u], mid + 1, r, x);
ss += asklx(lf[u], l, mid, x); return ss;
}
int askrx(int u, int l, int r, int x) {
int ss = 0;
if(x <= l) {
ss = grx(u, value); value = Max(value, mx[u]); return ss;
}
int mid = (l + r) >> 1;
if(mid >= x) ss += askrx(lf[u], l, mid, x);
ss += askrx(rf[u], mid + 1, r, x); return ss;
}
int dep(int x) {
int ss = -1;
value = 0, ss += asklx(rt, 1, lp, x);
value = 0, ss += askrx(rt, 1, lp, x);
return ss;
}
int LCA(int x, int y) {
value = 0; if(x > y) swap(x, y); gmx(rt, 1, lp, x, y); return lca;
}
int dis(int x, int y) { return dep(x) + dep(y) - 2 * dep(LCA(x, y)); }
int main() {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
scanf("%d%lld", &q[i].opt, &q[i].x);
if(q[i].opt ^ 1) scanf("%lld", &q[i].y);
if(!q[i].opt) p[++lp] = q[i].x;
}
sort(p + 1, p + lp + 1), lp = unique(p + 1, p + lp + 1) - p - 1;
build(rt, 1, lp);
for(int i = 1; i <= n; i++) {
q[i].x = lower_bound(p + 1, p + lp + 1, q[i].x) - p;
if(q[i].opt == 0) mdy(rt, 1, lp, q[i].x, q[i].y);
else if(q[i].opt == 1) mdy(rt, 1, lp, q[i].x, 0);
else {
q[i].y = lower_bound(p + 1, p + lp + 1, q[i].y) - p;
printf("%d\n", dis(q[i].x, q[i].y));
}
}
return 0;
}