summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorRich Hickey <richhickey@gmail.com>2006-05-22 19:42:40 +0000
committerRich Hickey <richhickey@gmail.com>2006-05-22 19:42:40 +0000
commit4770e052cc5df518e2d32632398e46abacb93524 (patch)
treeb86496dd96c13a61163fff2399d7418051bffca9 /src
parent7e7516821258dd51e411b51a8851c54320f64dc3 (diff)
added remove
Diffstat (limited to 'src')
-rw-r--r--src/org/clojure/runtime/RBSet.java258
1 files changed, 226 insertions, 32 deletions
diff --git a/src/org/clojure/runtime/RBSet.java b/src/org/clojure/runtime/RBSet.java
index 0c81c656..b6079964 100644
--- a/src/org/clojure/runtime/RBSet.java
+++ b/src/org/clojure/runtime/RBSet.java
@@ -14,6 +14,14 @@ package org.clojure.runtime;
import java.util.*;
+/**
+ * Persistent Red Black Tree
+ * Note that instances of this class are constant values
+ * i.e. add/remove etc return new values
+ * <p/>
+ * See Okasaki, Kahrs, Larsen
+ */
+
public class RBSet{
public final Comparator comp;
@@ -38,30 +46,33 @@ public RBSet add(Object key){
return put(key, null);
}
-public RBSet put(Object key,Object val){
+public RBSet put(Object key, Object val){
Box found = new Box(null);
Node t = add(tree, key, val, found);
if(t == null) //null == already contains key
{
- Node foundNode = (Node) found.val;
+ Node foundNode = (Node) found.val;
if(foundNode.val() == val) //note only get same collection on identity of val, not equals()
return this;
- return new RBSet(comp, replace(tree,key,val), count);
+ return new RBSet(comp, replace(tree, key, val), count);
}
return new RBSet(comp, t.blacken(), count + 1);
}
-/*
public RBSet remove(Object key){
- Node t = remove(tree, key);
- if(t == null) //null == doesn't contain key
- return this;
- if(t instanceof Red)
- t = blacken(t);
- return new RBSet(comp, t, count - 1);
+ Box found = new Box(null);
+ Node t = remove(tree, key, found);
+ if(t == null)
+ {
+ if(found.val == null)//null == doesn't contain key
+ return this;
+ //empty
+ return new RBSet(comp);
+ }
+ return new RBSet(comp, t.blacken(), count - 1);
}
-*/
+
public NodeIterator iterator(){
return new NodeIterator(tree, true);
@@ -87,24 +98,34 @@ public Iterator vals(NodeIterator it){
return new ValIterator(it);
}
-public Object min(){
+public Object minKey(){
+ Node t = min();
+ return t!=null?t.key:null;
+}
+
+public Node min(){
Node t = tree;
if(t != null)
{
while(t.left() != null)
t = t.left();
}
- return t != null ? t.key : null;
+ return t;
}
-public Object max(){
+public Object maxKey(){
+ Node t = max();
+ return t!=null?t.key:null;
+}
+
+public Node max(){
Node t = tree;
if(t != null)
{
while(t.right() != null)
t = t.right();
}
- return t != null ? t.key : null;
+ return t;
}
public int depth(){
@@ -143,7 +164,7 @@ Node add(Node t, Object key, Object val, Box found){
{
if(val == null)
return new Red(key);
- return new RedVal(key,val);
+ return new RedVal(key, val);
}
int c = compare(key, t.key);
if(c == 0)
@@ -159,12 +180,121 @@ Node add(Node t, Object key, Object val, Box found){
return t.addRight(ins);
}
+Node remove(Node t, Object key, Box found){
+ if(t == null)
+ return null; //not found indicator
+ int c = compare(key, t.key);
+ if(c == 0)
+ {
+ found.val = t;
+ return append(t.left(), t.right());
+ }
+ Node del = c < 0 ? remove(t.left(), key, found) : remove(t.right(), key, found);
+ if(del == null && found.val == null) //not found below
+ return null;
+ if(c < 0)
+ {
+ if(t.left() instanceof Black)
+ return balanceLeftDel(t.key, t.val(), del, t.right());
+ else
+ return red(t.key, t.val(), del, t.right());
+ }
+ if(t.right() instanceof Black)
+ return balanceRightDel(t.key, t.val(), t.left(), del);
+ return red(t.key, t.val(), t.left(), del);
+// return t.removeLeft(del);
+// return t.removeRight(del);
+}
+
+static Node append(Node left, Node right){
+ if(left == null)
+ return right;
+ else if(right == null)
+ return left;
+ else if(left instanceof Red)
+ {
+ if(right instanceof Red)
+ {
+ Node app = append(left.right(), right.left());
+ if(app instanceof Red)
+ return red(app.key, app.val(),
+ red(left.key, left.val(), left.left(), app.left()),
+ red(right.key, right.val(), app.right(), right.right()));
+ else
+ return red(left.key, left.val(), left.left(), red(right.key, right.val(), app, right.right()));
+ }
+ else
+ return red(left.key, left.val(), left.left(), append(left.right(), right));
+ }
+ else if(right instanceof Red)
+ return red(right.key, right.val(), append(left, right.left()), right.right());
+ else //black/black
+ {
+ Node app = append(left.right(), right.left());
+ if(app instanceof Red)
+ return red(app.key, app.val(),
+ black(left.key, left.val(), left.left(), app.left()),
+ black(right.key, right.val(), app.right(), right.right()));
+ else
+ return balanceLeftDel(left.key, left.val(), left.left(), black(right.key, right.val(), app, right.right()));
+ }
+}
+
+static Node balanceLeftDel(Object key, Object val, Node del, Node right){
+ if(del instanceof Red)
+ return red(key, val, del.blacken(), right);
+ else if(right instanceof Black)
+ return rightBalance(key, val, del, right.redden());
+ else if(right instanceof Red && right.left() instanceof Black)
+ return red(right.left().key, right.left().val(),
+ black(key, val, del, right.left().left()),
+ rightBalance(right.key, right.val(), right.left().right(), right.right().redden()));
+ else
+ throw new UnsupportedOperationException("Invariant violation");
+}
+
+static Node balanceRightDel(Object key, Object val, Node left, Node del){
+ if(del instanceof Red)
+ return red(key, val, left, del.blacken());
+ else if(left instanceof Black)
+ return leftBalance(key, val, left.redden(), del);
+ else if(left instanceof Red && left.right() instanceof Black)
+ return red(left.right().key, left.right().val(),
+ leftBalance(left.key, left.val(), left.left().redden(), left.right().left()),
+ black(key, val, left.right().right(), del));
+ else
+ throw new UnsupportedOperationException("Invariant violation");
+}
+
+static Node leftBalance(Object key, Object val, Node ins, Node right){
+ if(ins instanceof Red && ins.left() instanceof Red)
+ return red(ins.key, ins.val(), ins.left().blacken(), black(key, val, ins.right(), right));
+ else if(ins instanceof Red && ins.right() instanceof Red)
+ return red(ins.right().key, ins.right().val(),
+ black(ins.key, ins.val(), ins.left(), ins.right().left()),
+ black(key, val, ins.right().right(), right));
+ else
+ return black(key, val, ins, right);
+}
+
+
+static Node rightBalance(Object key, Object val, Node left, Node ins){
+ if(ins instanceof Red && ins.right() instanceof Red)
+ return red(ins.key, ins.val(), black(key, val, left, ins.left()), ins.right().blacken());
+ else if(ins instanceof Red && ins.left() instanceof Red)
+ return red(ins.left().key, ins.left().val(),
+ black(key, val, left, ins.left().left()),
+ black(ins.key, ins.val(), ins.left().right(), ins.right()));
+ else
+ return black(key, val, left, ins);
+}
+
Node replace(Node t, Object key, Object val){
int c = compare(key, t.key);
return t.replace(t.key,
- c==0?val:t.val(),
- c<0?replace(t.left(),key,val):t.left(),
- c>0?replace(t.right(),key,val):t.right());
+ c == 0 ? val : t.val(),
+ c < 0 ? replace(t.left(), key, val) : t.left(),
+ c > 0 ? replace(t.right(), key, val) : t.right());
}
RBSet(Comparator comp, Node tree, int count){
@@ -225,8 +355,14 @@ static abstract class Node{
abstract Node addRight(Node ins);
+ abstract Node removeLeft(Node del);
+
+ abstract Node removeRight(Node del);
+
abstract Node blacken();
+ abstract Node redden();
+
Node balanceLeft(Node parent){
return black(parent.key, parent.val(), this, parent.right());
}
@@ -251,10 +387,22 @@ static class Black extends Node{
return ins.balanceRight(this);
}
+ Node removeLeft(Node del){
+ return balanceLeftDel(key, val(), del, right());
+ }
+
+ Node removeRight(Node del){
+ return balanceRightDel(key, val(), left(), del);
+ }
+
Node blacken(){
return this;
}
+ Node redden(){
+ return new Red(key);
+ }
+
Node replace(Object key, Object val, Node left, Node right){
return black(key, val, left, right);
}
@@ -262,7 +410,8 @@ static class Black extends Node{
static class BlackVal extends Black{
final Object val;
- public BlackVal(Object key,Object val){
+
+ public BlackVal(Object key, Object val){
super(key);
this.val = val;
}
@@ -270,16 +419,22 @@ static class BlackVal extends Black{
public Object val(){
return val;
}
+
+ Node redden(){
+ return new RedVal(key, val);
+ }
+
}
static class BlackBranch extends Black{
final Node left;
final Node right;
- public BlackBranch(Object key,Node left,Node right){
+
+ public BlackBranch(Object key, Node left, Node right){
super(key);
this.left = left;
this.right = right;
- }
+ }
public Node left(){
return left;
@@ -288,17 +443,29 @@ static class BlackBranch extends Black{
public Node right(){
return right;
}
+
+ Node redden(){
+ return new RedBranch(key, left, right);
+ }
+
}
static class BlackBranchVal extends BlackBranch{
final Object val;
- public BlackBranchVal(Object key,Object val,Node left,Node right){
+
+ public BlackBranchVal(Object key, Object val, Node left, Node right){
super(key, left, right);
this.val = val;
}
+
public Object val(){
return val;
}
+
+ Node redden(){
+ return new RedBranchVal(key, val, left, right);
+ }
+
}
static class Red extends Node{
@@ -314,10 +481,22 @@ static class Red extends Node{
return red(key, val(), left(), ins);
}
+ Node removeLeft(Node del){
+ return red(key, val(), del, right());
+ }
+
+ Node removeRight(Node del){
+ return red(key, val(), left(), del);
+ }
+
Node blacken(){
return new Black(key);
}
+ Node redden(){
+ throw new UnsupportedOperationException("Invariant violation");
+ }
+
Node replace(Object key, Object val, Node left, Node right){
return red(key, val, left, right);
}
@@ -325,7 +504,8 @@ static class Red extends Node{
static class RedVal extends Red{
final Object val;
- public RedVal(Object key,Object val){
+
+ public RedVal(Object key, Object val){
super(key);
this.val = val;
}
@@ -342,11 +522,12 @@ static class RedVal extends Red{
static class RedBranch extends Red{
final Node left;
final Node right;
- public RedBranch(Object key,Node left,Node right){
+
+ public RedBranch(Object key, Node left, Node right){
super(key);
this.left = left;
this.right = right;
- }
+ }
public Node left(){
return left;
@@ -360,7 +541,7 @@ static class RedBranch extends Red{
if(left instanceof Red)
return red(key, val(), left.blacken(), black(parent.key, parent.val(), right, parent.right()));
else if(right instanceof Red)
- return red(right.key, right.val(),black(key, val(), left, right.left()),
+ return red(right.key, right.val(), black(key, val(), left, right.left()),
black(parent.key, parent.val(), right.right(), parent.right()));
else
return super.balanceLeft(parent);
@@ -384,10 +565,12 @@ static class RedBranch extends Red{
static class RedBranchVal extends RedBranch{
final Object val;
- public RedBranchVal(Object key,Object val,Node left,Node right){
+
+ public RedBranchVal(Object key, Object val, Node left, Node right){
super(key, left, right);
this.val = val;
}
+
public Object val(){
return val;
}
@@ -431,6 +614,7 @@ static public class NodeIterator implements Iterator{
static class KeyIterator implements Iterator{
NodeIterator it;
+
KeyIterator(NodeIterator it){
this.it = it;
}
@@ -440,7 +624,7 @@ static class KeyIterator implements Iterator{
}
public Object next(){
- return ((Node)it.next()).key;
+ return ((Node) it.next()).key;
}
public void remove(){
@@ -450,6 +634,7 @@ static class KeyIterator implements Iterator{
static class ValIterator implements Iterator{
NodeIterator it;
+
ValIterator(NodeIterator it){
this.it = it;
}
@@ -459,7 +644,7 @@ static class ValIterator implements Iterator{
}
public Object next(){
- return ((Node)it.next()).val();
+ return ((Node) it.next()).val();
}
public void remove(){
@@ -487,9 +672,9 @@ static public void main(String args[]){
for(int i = 0; i < ints.length; i++)
{
Integer anInt = ints[i];
- set = set.put(anInt,anInt);
+ set = set.put(anInt, anInt);
}
- System.out.println("count = " + set.count + ", min: " + set.min() + ", max: " + set.max()
+ System.out.println("count = " + set.count + ", min: " + set.minKey() + ", max: " + set.maxKey()
+ ", depth: " + set.depth());
Iterator it = set.keys();
while(it.hasNext())
@@ -500,5 +685,14 @@ static public void main(String args[]){
else if(n < 2000)
System.out.print(o.toString() + ",");
}
+ it = set.keys();
+ while(it.hasNext())
+ {
+ Object o = it.next();
+ set = set.remove(o);
+ }
+ System.out.println();
+ System.out.println("count = " + set.count + ", min: " + set.minKey() + ", max: " + set.maxKey()
+ + ", depth: " + set.depth());
}
}