Source (Plain Text)
template<typename K>
class avltree {
private:
struct node {
K key; int h, size;
node *l, *r;
node() { size=0; h=-1; }
node (K _key) { key=_key; size=1; h=0; }
} *root, *NIL;
void clear (node* u) {
if (u == NIL) return;
clear (u->l); clear (u->r);
delete u;
}
void update (node* u) {
u->h = 1 + max (u->r->h, u->l->h);
u->size = 1 + u->r->size + u->l->size;
} node* zag (node* u) {
node *v = u->l;
u->l = v->r;
v->r = u; update (v->r);
return v;
} node* zig (node* u) {
node *v = u->r;
u->r = v->l;
v->l = u; update (v->l);
return v;
}
node* balance (node *u) {
int d = u->l->h - u->r->h;
if (d > 1) {
if (u->l->l->h > u->l->r->h) u = zag (u);
else {
u->l = zig (u->l);
u = zag (u);
}
} else if (d < -1) {
if (u->r->r->h > u->r->l->h) u = zig (u);
else {
u->r = zag (u->r);
u = zig (u);
}
}
update (u);
return u;
}
node* insert (node *u, K key) {
if (u == NIL) {
node *v = new node (key); v->l=v->r=NIL; return v;
} else {
if (key < u->key) u->l = insert (u->l, key);
else if (key > u->key) u->r = insert (u->r, key);
return balance (u);
}
}
node* next (node *u) {
u=u->r;
while (u->l != NIL) u=u->l;
return u;
}
node* removeMin (node *u) {
if (u->l == NIL) { node* v = u->r; delete u; return v; }
else {
u->l = removeMin (u->l);
return balance (u);
}
}
node* remove (node *u, K key) {
if (u == NIL) return NIL;
if (key < u->key) u->l = remove (u->l, key);
else if (key > u->key) u->r = remove (u->r, key);
else {
if (u->l == NIL) { node* v = u->r; delete u; return v; }
if (u->r == NIL) { node* v = u->l; delete u; return v; }
node *v = next (u);
u->key = v->key;
u->r = removeMin (u->r);
}
return balance (u);
}
K select (node *u, int i) {
int r = u->l->size;
if (i == r) return u->key;
else if (i < r) return select (u->l, i);
else return select (u->r, i-r-1);
}
int rank (node *u, K key) {
if (u == NIL) return 0;
if (key == u->key) return u->l->size;
if (key < u->key) return rank (u->l, key);
return u->l->size + 1 + rank (u->r, key);
}
public:
avltree() { NIL=new node(); NIL->l=NIL->r=NIL; root=NIL; }
~avltree() { clear(); delete NIL; }
void clear() { clear (root); root=NIL; }
int size() { return root->size; }
void insert (K key) { root = insert (root, key); }
void remove (K key) { root = remove (root, key); }
bool find (K key) {
node *u = root;
while (u != NIL) {
if (key < u->key) u=u->l;
else if (key > u->key) u=u->r;
else return true;
}
return false;
}
K select (int i) { return select (root, i); }
int rank (K key) { return rank (root, key); }
};