diff options
author | Rich Hickey <richhickey@gmail.com> | 2006-05-22 19:42:40 +0000 |
---|---|---|
committer | Rich Hickey <richhickey@gmail.com> | 2006-05-22 19:42:40 +0000 |
commit | 4770e052cc5df518e2d32632398e46abacb93524 (patch) | |
tree | b86496dd96c13a61163fff2399d7418051bffca9 /src | |
parent | 7e7516821258dd51e411b51a8851c54320f64dc3 (diff) |
added remove
Diffstat (limited to 'src')
-rw-r--r-- | src/org/clojure/runtime/RBSet.java | 258 |
1 files changed, 226 insertions, 32 deletions
diff --git a/src/org/clojure/runtime/RBSet.java b/src/org/clojure/runtime/RBSet.java index 0c81c656..b6079964 100644 --- a/src/org/clojure/runtime/RBSet.java +++ b/src/org/clojure/runtime/RBSet.java @@ -14,6 +14,14 @@ package org.clojure.runtime; import java.util.*; +/** + * Persistent Red Black Tree + * Note that instances of this class are constant values + * i.e. add/remove etc return new values + * <p/> + * See Okasaki, Kahrs, Larsen + */ + public class RBSet{ public final Comparator comp; @@ -38,30 +46,33 @@ public RBSet add(Object key){ return put(key, null); } -public RBSet put(Object key,Object val){ +public RBSet put(Object key, Object val){ Box found = new Box(null); Node t = add(tree, key, val, found); if(t == null) //null == already contains key { - Node foundNode = (Node) found.val; + Node foundNode = (Node) found.val; if(foundNode.val() == val) //note only get same collection on identity of val, not equals() return this; - return new RBSet(comp, replace(tree,key,val), count); + return new RBSet(comp, replace(tree, key, val), count); } return new RBSet(comp, t.blacken(), count + 1); } -/* public RBSet remove(Object key){ - Node t = remove(tree, key); - if(t == null) //null == doesn't contain key - return this; - if(t instanceof Red) - t = blacken(t); - return new RBSet(comp, t, count - 1); + Box found = new Box(null); + Node t = remove(tree, key, found); + if(t == null) + { + if(found.val == null)//null == doesn't contain key + return this; + //empty + return new RBSet(comp); + } + return new RBSet(comp, t.blacken(), count - 1); } -*/ + public NodeIterator iterator(){ return new NodeIterator(tree, true); @@ -87,24 +98,34 @@ public Iterator vals(NodeIterator it){ return new ValIterator(it); } -public Object min(){ +public Object minKey(){ + Node t = min(); + return t!=null?t.key:null; +} + +public Node min(){ Node t = tree; if(t != null) { while(t.left() != null) t = t.left(); } - return t != null ? t.key : null; + return t; } -public Object max(){ +public Object maxKey(){ + Node t = max(); + return t!=null?t.key:null; +} + +public Node max(){ Node t = tree; if(t != null) { while(t.right() != null) t = t.right(); } - return t != null ? t.key : null; + return t; } public int depth(){ @@ -143,7 +164,7 @@ Node add(Node t, Object key, Object val, Box found){ { if(val == null) return new Red(key); - return new RedVal(key,val); + return new RedVal(key, val); } int c = compare(key, t.key); if(c == 0) @@ -159,12 +180,121 @@ Node add(Node t, Object key, Object val, Box found){ return t.addRight(ins); } +Node remove(Node t, Object key, Box found){ + if(t == null) + return null; //not found indicator + int c = compare(key, t.key); + if(c == 0) + { + found.val = t; + return append(t.left(), t.right()); + } + Node del = c < 0 ? remove(t.left(), key, found) : remove(t.right(), key, found); + if(del == null && found.val == null) //not found below + return null; + if(c < 0) + { + if(t.left() instanceof Black) + return balanceLeftDel(t.key, t.val(), del, t.right()); + else + return red(t.key, t.val(), del, t.right()); + } + if(t.right() instanceof Black) + return balanceRightDel(t.key, t.val(), t.left(), del); + return red(t.key, t.val(), t.left(), del); +// return t.removeLeft(del); +// return t.removeRight(del); +} + +static Node append(Node left, Node right){ + if(left == null) + return right; + else if(right == null) + return left; + else if(left instanceof Red) + { + if(right instanceof Red) + { + Node app = append(left.right(), right.left()); + if(app instanceof Red) + return red(app.key, app.val(), + red(left.key, left.val(), left.left(), app.left()), + red(right.key, right.val(), app.right(), right.right())); + else + return red(left.key, left.val(), left.left(), red(right.key, right.val(), app, right.right())); + } + else + return red(left.key, left.val(), left.left(), append(left.right(), right)); + } + else if(right instanceof Red) + return red(right.key, right.val(), append(left, right.left()), right.right()); + else //black/black + { + Node app = append(left.right(), right.left()); + if(app instanceof Red) + return red(app.key, app.val(), + black(left.key, left.val(), left.left(), app.left()), + black(right.key, right.val(), app.right(), right.right())); + else + return balanceLeftDel(left.key, left.val(), left.left(), black(right.key, right.val(), app, right.right())); + } +} + +static Node balanceLeftDel(Object key, Object val, Node del, Node right){ + if(del instanceof Red) + return red(key, val, del.blacken(), right); + else if(right instanceof Black) + return rightBalance(key, val, del, right.redden()); + else if(right instanceof Red && right.left() instanceof Black) + return red(right.left().key, right.left().val(), + black(key, val, del, right.left().left()), + rightBalance(right.key, right.val(), right.left().right(), right.right().redden())); + else + throw new UnsupportedOperationException("Invariant violation"); +} + +static Node balanceRightDel(Object key, Object val, Node left, Node del){ + if(del instanceof Red) + return red(key, val, left, del.blacken()); + else if(left instanceof Black) + return leftBalance(key, val, left.redden(), del); + else if(left instanceof Red && left.right() instanceof Black) + return red(left.right().key, left.right().val(), + leftBalance(left.key, left.val(), left.left().redden(), left.right().left()), + black(key, val, left.right().right(), del)); + else + throw new UnsupportedOperationException("Invariant violation"); +} + +static Node leftBalance(Object key, Object val, Node ins, Node right){ + if(ins instanceof Red && ins.left() instanceof Red) + return red(ins.key, ins.val(), ins.left().blacken(), black(key, val, ins.right(), right)); + else if(ins instanceof Red && ins.right() instanceof Red) + return red(ins.right().key, ins.right().val(), + black(ins.key, ins.val(), ins.left(), ins.right().left()), + black(key, val, ins.right().right(), right)); + else + return black(key, val, ins, right); +} + + +static Node rightBalance(Object key, Object val, Node left, Node ins){ + if(ins instanceof Red && ins.right() instanceof Red) + return red(ins.key, ins.val(), black(key, val, left, ins.left()), ins.right().blacken()); + else if(ins instanceof Red && ins.left() instanceof Red) + return red(ins.left().key, ins.left().val(), + black(key, val, left, ins.left().left()), + black(ins.key, ins.val(), ins.left().right(), ins.right())); + else + return black(key, val, left, ins); +} + Node replace(Node t, Object key, Object val){ int c = compare(key, t.key); return t.replace(t.key, - c==0?val:t.val(), - c<0?replace(t.left(),key,val):t.left(), - c>0?replace(t.right(),key,val):t.right()); + c == 0 ? val : t.val(), + c < 0 ? replace(t.left(), key, val) : t.left(), + c > 0 ? replace(t.right(), key, val) : t.right()); } RBSet(Comparator comp, Node tree, int count){ @@ -225,8 +355,14 @@ static abstract class Node{ abstract Node addRight(Node ins); + abstract Node removeLeft(Node del); + + abstract Node removeRight(Node del); + abstract Node blacken(); + abstract Node redden(); + Node balanceLeft(Node parent){ return black(parent.key, parent.val(), this, parent.right()); } @@ -251,10 +387,22 @@ static class Black extends Node{ return ins.balanceRight(this); } + Node removeLeft(Node del){ + return balanceLeftDel(key, val(), del, right()); + } + + Node removeRight(Node del){ + return balanceRightDel(key, val(), left(), del); + } + Node blacken(){ return this; } + Node redden(){ + return new Red(key); + } + Node replace(Object key, Object val, Node left, Node right){ return black(key, val, left, right); } @@ -262,7 +410,8 @@ static class Black extends Node{ static class BlackVal extends Black{ final Object val; - public BlackVal(Object key,Object val){ + + public BlackVal(Object key, Object val){ super(key); this.val = val; } @@ -270,16 +419,22 @@ static class BlackVal extends Black{ public Object val(){ return val; } + + Node redden(){ + return new RedVal(key, val); + } + } static class BlackBranch extends Black{ final Node left; final Node right; - public BlackBranch(Object key,Node left,Node right){ + + public BlackBranch(Object key, Node left, Node right){ super(key); this.left = left; this.right = right; - } + } public Node left(){ return left; @@ -288,17 +443,29 @@ static class BlackBranch extends Black{ public Node right(){ return right; } + + Node redden(){ + return new RedBranch(key, left, right); + } + } static class BlackBranchVal extends BlackBranch{ final Object val; - public BlackBranchVal(Object key,Object val,Node left,Node right){ + + public BlackBranchVal(Object key, Object val, Node left, Node right){ super(key, left, right); this.val = val; } + public Object val(){ return val; } + + Node redden(){ + return new RedBranchVal(key, val, left, right); + } + } static class Red extends Node{ @@ -314,10 +481,22 @@ static class Red extends Node{ return red(key, val(), left(), ins); } + Node removeLeft(Node del){ + return red(key, val(), del, right()); + } + + Node removeRight(Node del){ + return red(key, val(), left(), del); + } + Node blacken(){ return new Black(key); } + Node redden(){ + throw new UnsupportedOperationException("Invariant violation"); + } + Node replace(Object key, Object val, Node left, Node right){ return red(key, val, left, right); } @@ -325,7 +504,8 @@ static class Red extends Node{ static class RedVal extends Red{ final Object val; - public RedVal(Object key,Object val){ + + public RedVal(Object key, Object val){ super(key); this.val = val; } @@ -342,11 +522,12 @@ static class RedVal extends Red{ static class RedBranch extends Red{ final Node left; final Node right; - public RedBranch(Object key,Node left,Node right){ + + public RedBranch(Object key, Node left, Node right){ super(key); this.left = left; this.right = right; - } + } public Node left(){ return left; @@ -360,7 +541,7 @@ static class RedBranch extends Red{ if(left instanceof Red) return red(key, val(), left.blacken(), black(parent.key, parent.val(), right, parent.right())); else if(right instanceof Red) - return red(right.key, right.val(),black(key, val(), left, right.left()), + return red(right.key, right.val(), black(key, val(), left, right.left()), black(parent.key, parent.val(), right.right(), parent.right())); else return super.balanceLeft(parent); @@ -384,10 +565,12 @@ static class RedBranch extends Red{ static class RedBranchVal extends RedBranch{ final Object val; - public RedBranchVal(Object key,Object val,Node left,Node right){ + + public RedBranchVal(Object key, Object val, Node left, Node right){ super(key, left, right); this.val = val; } + public Object val(){ return val; } @@ -431,6 +614,7 @@ static public class NodeIterator implements Iterator{ static class KeyIterator implements Iterator{ NodeIterator it; + KeyIterator(NodeIterator it){ this.it = it; } @@ -440,7 +624,7 @@ static class KeyIterator implements Iterator{ } public Object next(){ - return ((Node)it.next()).key; + return ((Node) it.next()).key; } public void remove(){ @@ -450,6 +634,7 @@ static class KeyIterator implements Iterator{ static class ValIterator implements Iterator{ NodeIterator it; + ValIterator(NodeIterator it){ this.it = it; } @@ -459,7 +644,7 @@ static class ValIterator implements Iterator{ } public Object next(){ - return ((Node)it.next()).val(); + return ((Node) it.next()).val(); } public void remove(){ @@ -487,9 +672,9 @@ static public void main(String args[]){ for(int i = 0; i < ints.length; i++) { Integer anInt = ints[i]; - set = set.put(anInt,anInt); + set = set.put(anInt, anInt); } - System.out.println("count = " + set.count + ", min: " + set.min() + ", max: " + set.max() + System.out.println("count = " + set.count + ", min: " + set.minKey() + ", max: " + set.maxKey() + ", depth: " + set.depth()); Iterator it = set.keys(); while(it.hasNext()) @@ -500,5 +685,14 @@ static public void main(String args[]){ else if(n < 2000) System.out.print(o.toString() + ","); } + it = set.keys(); + while(it.hasNext()) + { + Object o = it.next(); + set = set.remove(o); + } + System.out.println(); + System.out.println("count = " + set.count + ", min: " + set.minKey() + ", max: " + set.maxKey() + + ", depth: " + set.depth()); } } |