背景知识
1. 二叉搜索树BST
2. 堆Heap
操作
- 插入
- 删除
- 找前驱/后继
- 找最大/最小
- 求某个值的排名
- 求排名是k的数是哪个
- 找比某个数小的最大值
- 找比某个数大的最小值
结构定义
Node {
    int l, r;       // 左右子节点编号
    // key指二叉搜索树中的值,val指堆中的值(随机值),两个同时满足要求
    int key, val;
}
当所有key互不相同、所有val也互不相同时,这两个约束就能唯一确定一棵树的形态:val最大的节点必为根,key小于根的节点构成左子树、大于根的构成右子树,再对左右子树递归应用同样的规则。
插入
- 按二叉搜索树的规则找到位置,作为新的叶节点插入,并赋予随机值val
- 若新节点的val大于父节点的val,则通过旋转(左旋zag/右旋zig)将其上移,直到满足堆性质;这类似于堆的上浮(sift-up)操作
删除
- 找到目标节点
- 每次把val较大的那个子节点旋转上来(左旋或右旋),使目标节点不断下沉,直至成为叶节点
- 删除目标节点
模板
// 省略IOpublic class Main {static IntReader in;static FastWriter out;static String INPUT = "";static class Node {int l, r;int key, val;int cnt, size;}static final int INF = (int)(1e8), N = 100010;static Node[] tr = new Node[N];static int n, idx;static Random random = new Random();static void solve() {n = ni();int root = build();for (int i = 1; i <= n; i++) {int op = ni(), x = ni();if (op == 1)root = insert(root, x);else if (op == 2)root = delete(root, x);else if (op == 3)out.println(getRank(root, x) - 1);else if (op == 4)out.println(getKey(root, x + 1));else if (op == 5)out.println(getPrev(root, x));else if (op == 6)out.println(getNext(root, x));}}static int getNext(int u, int x) {if (u == 0) return INF;if (tr[u].key <= x)return getNext(tr[u].r, x);elsereturn Math.min(tr[u].key, getNext(tr[u].l, x));}static int getPrev(int u, int x) {if (u == 0) return -INF;if (tr[u].key >= x)return getPrev(tr[u].l, x);elsereturn Math.max(tr[u].key, getPrev(tr[u].r, x));}static int getKey(int u, int rank) {if (u == 0) return 0;if (tr[tr[u].l].size >= rank)return getKey(tr[u].l, rank);else if (tr[tr[u].l].size + tr[u].cnt >= rank)return tr[u].key;elsereturn getKey(tr[u].r, rank - tr[u].cnt - tr[tr[u].l].size);}static int getRank(int u, int key) {if (u == 0) return 0;if (tr[u].key == key) {return tr[tr[u].l].size + 1;} else if (tr[u].key > key) {return getRank(tr[u].l, key);} else {return tr[tr[u].l].size + tr[u].cnt + getRank(tr[u].r, key);}}static int delete(int u, int key) {if (u == 0) {return 0;}if (tr[u].key == key) {if (tr[u].cnt > 1)tr[u].cnt--;else if (tr[u].l != 0 || tr[u].r != 0) {if (tr[u].r == 0 || tr[u].l != 0 && tr[tr[u].l].val > tr[tr[u].r].val) {u = zig(u);tr[u].r = delete(tr[u].r, key);} else {u = zag(u);tr[u].l = delete(tr[u].l, key);}} else return 0;} else if (tr[u].key > key) {tr[u].l = delete(tr[u].l, key);} else {tr[u].r = delete(tr[u].r, key);}pushup(u);return u;}static int insert(int u, int key) {if (u == 0) {int p = createNode(key);return p;}if (tr[u].key == 
key) {tr[u].cnt++;} else if (tr[u].key > key) {tr[u].l = insert(tr[u].l, key);if (tr[u].val < tr[tr[u].l].val)u = zig(u);} else {tr[u].r = insert(tr[u].r, key);if (tr[u].val < tr[tr[u].r].val)u = zag(u);}pushup(u);return u;}static void pushup(int u) {tr[u].size = tr[tr[u].l].size + tr[tr[u].r].size + tr[u].cnt;}static int zag(int u) {int right = tr[u].r;tr[u].r = tr[right].l;tr[right].l = u;pushup(u);pushup(right);return right;}static int zig(int u) {int left = tr[u].l;tr[u].l = tr[left].r;tr[left].r = u;pushup(u);pushup(left);return left;}static int build() {tr[0] = new Node();int root = createNode(-INF), right = createNode(INF);tr[root].r = right;pushup(root);if (tr[root].val < tr[right].val)root = zag(root);return root;}static int createNode(int x) {++idx;tr[idx] = new Node();tr[idx].key = x;tr[idx].val = random.nextInt(2 * INF) + 1;tr[idx].size = tr[idx].cnt = 1;return idx;}public static void main(String[] args) throws Exception {in = INPUT.isEmpty() ? new IntReader(System.in) : new IntReader(new ByteArrayInputStream(INPUT.getBytes()));out = new FastWriter(System.out);solve();out.flush();}}
