1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
#include <bits/stdc++.h>
using namespace std;

// Shorthands for a node's children (son[0] = left, son[1] = right).
// Every preprocessor directive must sit on its own line: when both #defines
// share one line, `ls` swallows the rest of the line and `rs` is never defined.
#define ls son[0]
#define rs son[1]

// Size of splay_tree's node pool (buf[maxn]); slot 0 is the tree's sentinel.
// Written as an integral expression (same value as 1e5 + 5) so it is an
// integral constant expression usable as an array bound.
const int maxn = 100000 + 5;
template <class _Tp> class splay_tree { private: struct node; typedef node* pos; node buf[maxn]; int buf_cnt, fix_cnt; pos need_fix[maxn]; struct node { _Tp val; int cnt, size; pos son[2]; node() { cnt = size = 0; val = 0; ls = rs = NULL; } }; inline pos new_node(_Tp val, int cnt) { pos res = buf + (++buf_cnt); res -> ls = res -> rs = buf; res -> val = val; res -> cnt = res -> size = cnt; return res; } pos root; public: splay_tree() { buf -> ls = buf -> rs = buf; buf -> cnt = buf -> size = 0; root = buf; } void insert(_Tp val) { root = __insert(val, 1, root); } void insert(_Tp val, int cnt) { root = __insert(val, cnt, root); } void remove(_Tp val) { root = __remove(val, root); } void print() { __print(root); } void debug() { putchar('\n'); printf("root: %d\n", root - buf); for (int i = 0; i <= buf_cnt; i++) { printf("node#%d val: %d cnt: %d size: %d ls: %d rs: %d\n", i, buf[i].val, buf[i].cnt, buf[i].size, buf[i].ls - buf, buf[i].rs - buf); } putchar('\n'); } inline int rank(_Tp val) { root = splay_val(val, root); return rank_min(root); } inline _Tp kth(int k) { root = splay_rank(k, root); return root -> val; } _Tp pre(_Tp val) { pos t = root; _Tp ans; while (t != buf) { if (t -> val < val) { ans = t -> val; t = t -> rs; } else t = t -> ls; } return ans; } _Tp nxt(_Tp val) { pos t = root; _Tp ans; while (t != buf) { if (t -> val > val) { ans = t -> val; t = t -> ls; } else t = t -> rs; } return ans; } private: inline void fix_up(pos x) { x -> size = x -> ls -> size + x -> rs -> size + x -> cnt; } inline pos rotate(pos x, int with) { pos y = x -> son[with]; x -> son[with] = y -> son[with ^ 1]; y -> son[with ^ 1] = x; fix_up(x); fix_up(y); return y; } pos splay_val(_Tp val, pos t) { node header; header.ls = header.rs = buf; pos tmp[2] = {&header, &header}; fix_cnt = 0; while (t -> val != val) { int f1 = (val > t -> val); if (t -> son[f1] == buf) break; if (t -> son[f1] -> val != val) { int f2 = (val > t -> son[f1] -> val); if (f1 == f2) t = rotate(t, f1); if (t -> 
son[f1] == buf) break; } tmp[f1 ^ 1] -> son[f1] = t; tmp[f1 ^ 1] = t; need_fix[++fix_cnt] = t; t = t -> son[f1]; } tmp[0] -> rs = t -> ls; tmp[1] -> ls = t -> rs; t -> ls = header.rs; t -> rs = header.ls; for (int i = fix_cnt; i >= 1; i--) fix_up(need_fix[i]); fix_up(t); return t; } inline int rank_min(pos t) { return t -> ls -> size + 1; } inline int rank_max(pos t) { return t -> ls -> size + t -> cnt; } pos splay_rank(int k, pos t) { node header; header.ls = header.rs = buf; pos tmp[2] = {&header, &header}; fix_cnt = 0; while (rank_min(t) > k || rank_max(t) < k) {
int f1 = (rank_max(t) < k); if (f1 == 1) k -= rank_max(t);
if (t -> son[f1] == buf) break; if (rank_min(t -> son[f1]) > k || rank_max(t -> son[f1]) < k) { int f2 = (rank_max(t -> son[f1]) < k);
if (f1 == f2) { if (f2 == 1) k -= rank_max(t -> son[f1]); t = rotate(t, f1); } if (t -> son[f1] == buf) break; } tmp[f1 ^ 1] -> son[f1] = t; tmp[f1 ^ 1] = t; need_fix[++fix_cnt] = t; t = t -> son[f1];
} tmp[0] -> rs = t -> ls; tmp[1] -> ls = t -> rs; t -> ls = header.rs; t -> rs = header.ls; for (int i = fix_cnt; i >= 1; i--) fix_up(need_fix[i]); fix_up(t); return t; } pos __insert(_Tp val, int cnt, pos t) { pos p = new_node(val, cnt); if (t == buf) t = p; else { t = splay_val(val, t); if (t -> val == val) { t -> cnt++; t -> size++; buf_cnt--; return t; } int f = (val > t -> val); p -> son[f] = t -> son[f]; p -> son[f ^ 1] = t; t -> son[f] = buf; fix_up(t); t = p; fix_up(t); } return t; } pos __remove(_Tp val, pos t) { if (t != buf) { t = splay_val(val, t); if (val == t -> val) { t -> size--; t -> cnt--; if (t -> cnt == 0) { pos p; if (t -> ls == buf) p = t -> rs; else { p = t -> ls; p = splay_val(val, p); p -> rs = t -> rs; fix_up(p); } t -> ls = t -> rs = buf; t = p; } } } return t; } void __print(pos t) { if (t == buf) return; __print(t -> ls); for (int i = 1; i <= t -> cnt; i++) printf("%d ", t -> val); __print(t -> rs); } };
|