Problem
Given a Binary Search Tree (where all nodes in the left child branch are less than the node, and all nodes in the right branch are greater than or equal to the node), transform it into a Greater Sum Tree, where each node contains the sum of itself together with all nodes greater than that node. Example diagram is here:
Looking for code review, optimizations and best practices.
public class GreaterSumTree implements Iterable {
    private TreeNode root;

    public GreaterSumTree(List<Integer> list) {
        create(list);
    }

    private void create (List<Integer> items) {
        root = new TreeNode(items.get(0));
        final Queue<TreeNode> queue = new LinkedList<TreeNode>();
        queue.add(root);
        final int half = items.size() / 2;
        for (int i = 0; i < half; i++) {
            if (items.get(i) != null) {
                final TreeNode current = queue.poll();
                final int left = 2 * i + 1;
                final int right = 2 * i + 2;
                if (items.get(left) != null) {
                    current.left = new TreeNode(items.get(left));
                    queue.add(current.left);
                }
                if (right < items.size() && items.get(right) != null) {
                    current.right = new TreeNode(items.get(right));
                    queue.add(current.right);
                }
            }
        }
    }

    public static class TreeNode {
        private TreeNode left;
        private int item;
        private TreeNode right;

        TreeNode(int item) {
            this.item = item;
        }
    }

    public static class IntObject {
        private int sum;
    }

    /**
     * Computes the greater sum, provided the tree is BST.
     * If tree is not BST, then results are unpredictable.
     */
    public void greaterSumTree() {
        if (root == null) {
            throw new IllegalArgumentException("root is null");
        }
        computeSum (root, new IntObject());
    }

    private void computeSum(TreeNode node, IntObject intObj) {
        if (node != null) {
            computeSum(node.right, intObj);
            int temp = node.item;
            node.item = intObj.sum;
            intObj.sum = intObj.sum + temp;
            computeSum(node.left, intObj);
        }
    }

    /**
     * Returns the preorder representation for the given tree.
     *
     * @return the iterator for preorder traversal
     */
    @Override
    public Iterator iterator () {
        return new PreOrderItr();
    }

    private class PreOrderItr implements Iterator {
        private final Stack<TreeNode> stack;

        public PreOrderItr() {
            stack = new Stack<TreeNode>();
            stack.add(root);
        }

        @Override
        public boolean hasNext() {
            return !stack.isEmpty();
        }

        @Override
        public Integer next() {
            if (!hasNext()) throw new NoSuchElementException("No more nodes remain to iterate");
            final TreeNode node = stack.pop();
            if (node.right != null) stack.push(node.right);
            if (node.left != null) stack.push(node.left);
            return node.item;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("Invalid operation for pre-order iterator.");
        }
    }
}
Here is a test case:
public class GreaterSumTreeTest {
    @Test
    public void test() {
        Integer[] a = {1, 2, 7, 11, 15, 29, 35};
        GreaterSumTree greaterSumTree = new GreaterSumTree(Arrays.asList(a));
        greaterSumTree.greaterSumTree();
        int[] expected = {71, 87, 89, 72, 35, 42, 0};
        int[] actual = new int[a.length];
        int counter = 0;
        Iterator itr = greaterSumTree.iterator();
        while (itr.hasNext()) {
            actual[counter++] = (Integer) itr.next();
        }
        assertTrue(Arrays.equals(expected, actual));
    }
}
Solution
Compiler Warnings
Iterable is a raw type. References to generic type Iterable should be parameterized
This is very easy to fix. It’s horrible that you haven’t fixed it already. I’m sure that you know what generics are by now so I don’t need to explain them to you. Whenever you see this warning, add generics!
Iterable -> Iterable<Integer>
Problem solved!
Then you can get rid of the (Integer) cast next to itr.next();
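To make the effect concrete, here is a minimal, self-contained sketch. The class `IntBag` is hypothetical and only stands in for `GreaterSumTree`; the point is the generic signature. Once the class implements `Iterable<Integer>`, a for-each loop yields typed values and no cast is needed.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Hypothetical stand-in for GreaterSumTree; only the generic signature matters here.
final class IntBag implements Iterable<Integer> {
    private final List<Integer> items;

    IntBag(List<Integer> items) {
        this.items = items;
    }

    @Override
    public Iterator<Integer> iterator() {
        return items.iterator();
    }

    // Sums the elements; the loop variable is typed, so no (Integer) cast is needed.
    static int sum(IntBag bag) {
        int total = 0;
        for (int value : bag) {
            total += value;
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(sum(new IntBag(Arrays.asList(1, 2, 7)))); // prints 10
    }
}
```

The same change applies to the inner iterator class: declaring it as `Iterator<Integer>` removes the remaining raw-type warning.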
Why Binary Search Tree?
My first thought, once I understood the problem you’re solving, was: Why on earth do you have the data as a Binary Search Tree? What does the tree have to do with anything?
If this was a question given to me at an interview, I would ask them this. Why a Binary Search Tree? The operations you’re doing are totally unrelated to a tree. You could just as well perform this on a regular int[]. If there is a good reason for why it has to be a Binary Search Tree then please explain it to me.
Here’s a simple implementation of how to perform this on a regular int[], without using a Binary Tree:
public static void main(String[] args) {
    int[] input = new int[] { 1, 2, 7, 11, 15, 29, 35, 40 };
    int[] expected = new int[] { 139, 137, 130, 119, 104, 75, 40, 0 };
    transformToGreaterSum(input);
    System.out.println(Arrays.toString(input));
    System.out.println(Arrays.toString(expected));
}

private static void transformToGreaterSum(int[] input) {
    int sum = 0;
    for (int i = input.length - 1; i >= 0; i--) {
        int prevSum = sum;
        sum += input[i];
        input[i] = prevSum;
    }
}
Note that this is also O(n) without using any extra space.
Broken Tree
I’d also like to point out that your actual search tree does not match your image. I found out by debugging your code that your actual search tree looks like this:
The fact that it looks like this, and the fact that your algorithm works at all, is just because your input is sorted: {1, 2, 7, 11, 15, 29, 35}.
Your tree is in fact not a correct Binary Search Tree. A correct tree could look like this:
Yes, that’s still a tree. It just doesn’t have any left nodes (it does have seven nodes left though, I didn’t remove any). A child node that is bigger than its parent should be on the right, which makes this tree right. (I love the English language)
As if that was not enough, your int[] expected says:
int[] expected = {71, 87, 89, 72, 35, 42, 0};
But your image says:
119, 137, 75, 139, 130, 104, 0, 40
Given your description:
Given a BST, transform it into greater sum tree where each node contains sum of all nodes greater than that node. Diagram is here.
I would expect the node for 11 to get the value 15 + 29 + 35 + 40 = 119. I don’t see 119 anywhere in your expected or actual result though…
Verifying a Binary Search Tree
As your question contains an incorrect tree and one of your comments below has an incorrect tree, I’d like to provide you with a method to verify a correct tree:
private static void verifyNode(TreeNode node) {
    if (node.left != null) {
        if (node.left.item >= node.item) {
            throw new IllegalStateException("node.left >= node: " + node.left.item + " >= " + node.item);
        }
        verifyNode(node.left);
    }
    if (node.right != null) {
        if (node.right.item <= node.item) {
            throw new IllegalStateException("node.right <= node: " + node.right.item + " <= " + node.item);
        }
        verifyNode(node.right);
    }
}
Just pass it your root node and it will verify your tree recursively. Note that there are two things this method does not verify that are also requirements for a Binary Search Tree:
- There must be no duplicate nodes.
- A unique path exists from the root to every other node.
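A further caveat: the method above compares each node only with its direct children, so a value misplaced deeper in a subtree can slip through. A stricter variant, sketched here with a hypothetical minimal `Node` class (not the `TreeNode` from the question), passes the allowed range down the recursion. Like the check above, it treats duplicates as invalid.

```java
// Hypothetical minimal node type, used only for this sketch.
final class Node {
    final int item;
    Node left;
    Node right;

    Node(int item) {
        this.item = item;
    }
}

final class BstCheck {
    // Returns true if every node lies strictly between the bounds inherited
    // from its ancestors. long bounds avoid edge cases at Integer.MIN_VALUE
    // and Integer.MAX_VALUE.
    static boolean isBst(Node node, long low, long high) {
        if (node == null) {
            return true;
        }
        if (node.item <= low || node.item >= high) {
            return false;
        }
        return isBst(node.left, low, node.item)
                && isBst(node.right, node.item, high);
    }

    static boolean isBst(Node root) {
        return isBst(root, Long.MIN_VALUE, Long.MAX_VALUE);
    }

    public static void main(String[] args) {
        // 5 has left child 3, which has right child 7. Each parent/child
        // pair looks fine, but 7 sits in the left subtree of 5.
        Node root = new Node(5);
        root.left = new Node(3);
        root.left.right = new Node(7);
        System.out.println(isBst(root)); // prints false
    }
}
```

A parent/child-only check would accept the tree built in `main`, which is why the bounds are carried along.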
Arrays and Lists
Speaking of int[], I wonder why your constructor takes a List<Integer> instead of an array: int[]. I don’t think you gain anything by using a list.
The only thing you gain at the moment by using Integer rather than int is the possibility for this:
if (items.get(i) != null) {
As your input never contains null (and I see no reason why it should), I don’t get the point of using that non-null check. I modified your code to use int[] instead and removed the non-null checks and all worked well still. It even got rid of some indentation steps and felt like a cleanup overall.
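One possible shape for such a cleanup, as a sketch only: it assumes the same level-order layout as the original `create` (children of index `i` at `2 * i + 1` and `2 * i + 2`). The names `Node` and `LevelOrderBuilder` are hypothetical, and this is not necessarily the modified code referred to above.

```java
// Hypothetical minimal node type for this sketch.
final class Node {
    final int item;
    Node left;
    Node right;

    Node(int item) {
        this.item = item;
    }
}

final class LevelOrderBuilder {
    // Builds a tree from a level-order array; returns null for a null or empty array.
    static Node build(int[] items) {
        if (items == null || items.length == 0) {
            return null;
        }
        // With primitives there is no null to check, so every index becomes a node.
        Node[] nodes = new Node[items.length];
        for (int i = 0; i < items.length; i++) {
            nodes[i] = new Node(items[i]);
        }
        // Link each node to its children, if those indices exist.
        for (int i = 0; i < items.length; i++) {
            int left = 2 * i + 1;
            int right = 2 * i + 2;
            if (left < items.length) {
                nodes[i].left = nodes[left];
            }
            if (right < items.length) {
                nodes[i].right = nodes[right];
            }
        }
        return nodes[0];
    }

    public static void main(String[] args) {
        Node root = build(new int[] { 1, 2, 7, 11, 15, 29, 35 });
        System.out.println(root.left.item + " " + root.right.item); // prints 2 7
    }
}
```

Note that this keeps the level-order layout of the original, so it builds a valid Binary Search Tree only if the input happens to be arranged that way.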
Iterate over what?
If you had your Iterator as Iterator<TreeNode> instead of Iterator<Integer>, you could use your iterator in your computeSum method as well as in your test method to fetch the results. This would require modifying your iterator implementation, though (but as your tree is not a real binary search tree, it’s quite broken already).
Or you could just ignore the tree structure and use arrays….
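For the iterator idea, a rough sketch of how a node iterator could drive the sum. Note that it swaps the pre-order traversal for a reverse in-order one (right, node, left), because the running sum needs the nodes from largest to smallest. `Node` and `DescendingNodes` are hypothetical names, and Java 8 or later is assumed so that `remove()` does not have to be overridden.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.NoSuchElementException;

// Hypothetical minimal node type for this sketch.
final class Node {
    int item;
    Node left;
    Node right;

    Node(int item) {
        this.item = item;
    }
}

// Visits the nodes of a BST from largest to smallest.
final class DescendingNodes implements Iterator<Node> {
    private final Deque<Node> stack = new ArrayDeque<>();

    DescendingNodes(Node root) {
        pushRightSpine(root);
    }

    // Pushes the node and all of its right descendants.
    private void pushRightSpine(Node node) {
        while (node != null) {
            stack.push(node);
            node = node.right;
        }
    }

    @Override
    public boolean hasNext() {
        return !stack.isEmpty();
    }

    @Override
    public Node next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        Node node = stack.pop();
        pushRightSpine(node.left);
        return node;
    }

    // Replaces each item with the sum of all greater items.
    static void toGreaterSum(Node root) {
        int sum = 0;
        Iterator<Node> it = new DescendingNodes(root);
        while (it.hasNext()) {
            Node node = it.next();
            int original = node.item;
            node.item = sum;
            sum += original;
        }
    }

    public static void main(String[] args) {
        Node root = new Node(2);
        root.left = new Node(1);
        root.right = new Node(7);
        toGreaterSum(root);
        System.out.println(root.left.item + " " + root.item + " " + root.right.item); // prints 9 7 0
    }
}
```

The same iterator could also collect results in a test, which is the reuse hinted at above.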
Summary
Given how many trees I’ve seen you implement, I really thought your code would be better than this. Your current code doesn’t live up to your requirements, or your requirements are messed up. Additionally, the “raw type” mistake is a mistake one just should not make after having written just a few Java classes.
A minor bug: You get an IndexOutOfBoundsException for an empty list in GreaterSumTree.create(List<Integer> items). You don’t have a comment stating that the input list must contain at least one element. Consider throwing an IllegalArgumentException and adding a comment.
/**
 * Computes the greater sum, provided the tree is BST.
 * If tree is not BST, then results are unpredictable.
 */
public void greaterSumTree() {
    if (root == null) {
        throw new IllegalArgumentException("root is null");
    }
    computeSum (root, new IntObject());
}
You should use IllegalStateException here. IllegalArgumentException is for when the arguments are bad. What’s bad here is the internal state of your object. Alternatively, remove the check here and put it in the create function. That way you don’t get bad internal states.
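A minimal sketch of the second option, where the input is validated up front so that no instance with a missing root can exist. `CheckedTree` is a hypothetical, stripped-down class; the actual tree building is omitted.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical, stripped-down holder; tree building is omitted.
final class CheckedTree {
    private final int rootItem;

    /**
     * @param items level-order values; must contain at least one element
     * @throws IllegalArgumentException if items is null or empty
     */
    CheckedTree(List<Integer> items) {
        if (items == null || items.isEmpty()) {
            throw new IllegalArgumentException("items must contain at least one element");
        }
        rootItem = items.get(0);
    }

    int rootItem() {
        return rootItem;
    }

    public static void main(String[] args) {
        System.out.println(new CheckedTree(Arrays.asList(11, 2, 29)).rootItem()); // prints 11
        try {
            new CheckedTree(Collections.<Integer>emptyList());
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Here IllegalArgumentException is appropriate because the caller passed a bad argument. If the null check stays in a no-argument method instead, IllegalStateException is the better fit.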
Simon has covered much of the code-style aspects of your question. I agree with what he says in terms of how you present and manage your tree. I especially find the missing Generics to be particularly grating.
Apart from the code style issues, I want to review your algorithm. Additionally, there are some items I believe should be said in addition to Simon’s comments.
Overkill
Why have an iterator at all? It appears to me that the only use of the iterator is in the test case… right? This is a lot of extra work to perform for something that just does a pre-order traversal of the BST… right? A method on your class like int[] toArray() would be a better solution.
Why have the IntObject ‘accumulator’, and especially, why is it public? There are better ways to do this.
What the question is looking for
What you have done is repeated loops to access the sum of each node, and from that you accumulate the values into an IntObject.
What you should be doing is a form of dynamic programming, where you use the tree itself to accumulate the results…. Think of this:
- If you go to the right-most node (the largest), it will be the sum of itself.
- Go up from there, and you add the value of the just-computed largest to the current node, and you get the new value.
- Then descend the left side, and populate that tree.
- Each time you come up the recursion, you return the sum of the tree from there.
The algorithm becomes a form of reversed depth-first traversal. The implementation will be simple:
/**
 * Computes the greater sum, provided the tree is BST.
 * If tree is not BST, then results are unpredictable.
 */
public void greaterSumTree() {
    adjustSum (root, 0);
}

private int adjustSum(TreeNode node, int largerSum) {
    if (node == null) {
        return largerSum;
    }
    int current = node.item;
    node.item = adjustSum(node.right, largerSum);
    return adjustSum(node.left, node.item + current);
}
In the above process there is a single scan of the tree from largest back to smallest, with the correct accumulation of sums as input/output parameters. The tree is modified in-place as a form of dynamic programming.
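As a quick sanity check of the recursion, the following self-contained sketch runs the same logic on the values 11, 2, 29, 1, 7, 15, 40, 35 and reads the result back in order. The `Node` class, the recursive `insert` and the `transform` helper are hypothetical scaffolding for the demo; only the body of `adjustSum` mirrors the method above.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical minimal node type for this demo.
final class Node {
    int item;
    Node left;
    Node right;

    Node(int item) {
        this.item = item;
    }
}

final class GreaterSumDemo {
    // Plain BST insert; values equal to a node go to the right.
    static Node insert(Node root, int value) {
        if (root == null) {
            return new Node(value);
        }
        if (value < root.item) {
            root.left = insert(root.left, value);
        } else {
            root.right = insert(root.right, value);
        }
        return root;
    }

    // Same logic as adjustSum above: visit right, then node, then left.
    static int adjustSum(Node node, int largerSum) {
        if (node == null) {
            return largerSum;
        }
        int current = node.item;
        node.item = adjustSum(node.right, largerSum);
        return adjustSum(node.left, node.item + current);
    }

    static void inOrder(Node node, List<Integer> out) {
        if (node == null) {
            return;
        }
        inOrder(node.left, out);
        out.add(node.item);
        inOrder(node.right, out);
    }

    // Builds the tree, applies the transformation, returns the in-order values.
    static List<Integer> transform(int[] values) {
        Node root = null;
        for (int value : values) {
            root = insert(root, value);
        }
        adjustSum(root, 0);
        List<Integer> out = new ArrayList<>();
        inOrder(root, out);
        return out;
    }

    public static void main(String[] args) {
        // prints [139, 137, 130, 119, 104, 75, 40, 0]
        System.out.println(transform(new int[] { 11, 2, 29, 1, 7, 15, 40, 35 }));
    }
}
```

The node that held 11 ends up with 15 + 29 + 35 + 40 = 119, and the largest node ends up with 0.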
BST
As for the rest of your class, I messed with a few things to simplify the testing. For a start, I made the node population dynamic, and it is always a BST, except for the fact that, after a re-sum, the tree is the opposite of a BST (with all items in the reverse order).
As a result, this tree is not really a BST because it allows the structure to be invalidated (by design).
Code Dump
Here’s the full code I used. Note the simplified create and the toArray methods.
import java.util.Arrays;

public class GreaterSumTree {
    private TreeNode root;

    public GreaterSumTree(final int[] items) {
        if (items == null || items.length == 0) {
            return;
        }
        for (int value : items) {
            add(value);
        }
    }

    public void add(int value) {
        TreeNode node = root;
        TreeNode addNode = new TreeNode(value);
        while (node != null) {
            if (value < node.item) {
                if (node.left == null) {
                    node.left = addNode;
                    return;
                }
                node = node.left;
            } else {
                if (node.right == null) {
                    node.right = addNode;
                    return;
                }
                node = node.right;
            }
        }
        // no nodes to add to. Must be root.
        root = addNode;
    }

    public static class TreeNode {
        private TreeNode left;
        private int item;
        private TreeNode right;

        TreeNode(int item) {
            this.item = item;
        }
    }

    /**
     * Computes the greater sum, provided the tree is BST.
     * If tree is not BST, then results are unpredictable.
     */
    public void greaterSumTree() {
        adjustSum (root, 0);
    }

    private int adjustSum(TreeNode node, int largerSum) {
        if (node == null) {
            return largerSum;
        }
        int current = node.item;
        node.item = adjustSum(node.right, largerSum);
        return adjustSum(node.left, node.item + current);
    }

    public int[] toArray() {
        IntArray array = new IntArray();
        addToArray(root, array);
        return array.getArray();
    }

    private void addToArray(TreeNode node, IntArray array) {
        if (node == null) {
            return;
        }
        addToArray(node.left, array);
        array.add(node.item);
        addToArray(node.right, array);
    }

    private static final class IntArray {
        private int[] data = new int[128];
        private int size = 0;

        public void add(int value) {
            if (size == data.length) {
                data = Arrays.copyOf(data, size + (size >> 2) + 16);
            }
            data[size++] = value;
        }

        public int[] getArray() {
            return Arrays.copyOf(data, size);
        }
    }

    public static void main(String[] args) {
        int[][] data = {
            { 0,1,2,3,4,5},
            { 11,2,29,1,7,15,40,35},
            { 4,8,66,4,1,2,6,0,1,0}
        };
        for (int[] items : data) {
            GreaterSumTree gst = new GreaterSumTree(items);
            int[] pre = gst.toArray();
            gst.greaterSumTree();
            int[] post = gst.toArray();
            System.out.printf("Processed %s%n initial: %s%n post: %s%n",
                    Arrays.toString(items), Arrays.toString(pre), Arrays.toString(post));
        }
    }
}