Longest Collatz sequence using dynamic programming in Java

Problem

I can’t see where else I could optimize this. Is there any way to optimize this code further? Just give me hints.

import java.io.*;
import java.util.*;

public class Solution {
    static int[] countarray = new int[5000000]; //for memoization
    static int count; //for calculating length
    public static void main(String[] args){
        countarray[0]=0;
        Scanner scan = new Scanner(System.in);
        int t = scan.nextInt(); //number of inputs
        while(t > 0){
            int num = scan.nextInt(); // number upto which we have to find maximum length of the sequence
            int max = 0; 
            int result = 1; 
            while(num >= 1){
                    count = 0;
                    if(countarray[num-1] != 0)
                        count = countarray[num-1];
                    else
                        counter(num);
                    if(count > max){
                        max = count;
                        result = num;
                    }
                    num--;
            }
            t--;
            System.out.println(result);    
        }

    }
    public static int counter(int temp){
        if(temp <= 1)
            return count;
        if(temp < 5000000 && countarray[temp-1] != 0){
            count += countarray[temp-1];
            return count;
        }
        if(temp % 2 == 0)
            counter(temp/2);
        else
            counter((3 * temp) + 1);
        count++;
        if(temp < 5000000){
            countarray[temp-1] = count;
        }
        return count;
    }
}

Solution

You memoized once, but you needed to do it twice

Suppose you started with the entire countarray filled in for free. Your program would still time out because you have $10^4$ test cases, and for each test case, you are searching up to $5 \cdot 10^6$ array entries for a total of $5 \cdot 10^{10}$ operations. In other words, your program runs in $O(TN)$ time, which is too slow.

Think about what happens when the first test case is 5000000 and the second test case is 4999999. When you found the answer for the first test case, you should have also found the answer for the second test case. But you didn’t memoize that anywhere. If you just did a pass where you found the answer for each n once and saved the answers in a second memoization array, then each test case would take $O(1)$ time. Then your total time would be $O(T+N)$, which is $O(N)$ to build the answer array and $O(T)$ to handle $T$ test cases.
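A minimal sketch of that second memoization array, under a few assumptions: the names `CollatzAnswers`, `buildBest`, `steps`, `best` and `LIMIT` are made up for illustration, and ties go to the larger start value because that is what the posted loop does (it scans downward and only updates on a strictly greater count).

```java
import java.util.Scanner;

class CollatzAnswers {
    // Assumed upper bound on any query; the question uses 5000000.
    static final int LIMIT = 5_000_000;

    /**
     * Returns best[], where best[n] is the start value in [1, n] with the
     * longest chain. Ties go to the larger start value.
     */
    static int[] buildBest(int limit) {
        int[] steps = new int[limit + 1]; // first memo: steps[n] = steps from n down to 1
        int[] best = new int[limit + 1];  // second memo: answer for every possible query
        if (limit >= 1) {
            best[1] = 1;
        }
        for (int i = 2; i <= limit; i++) {
            long n = i; // long, so that 3n + 1 cannot overflow
            int count = 0;
            while (n >= i) { // below i, the result is already stored
                n = (n & 1) == 0 ? n >> 1 : 3 * n + 1;
                count++;
            }
            steps[i] = count + steps[(int) n];
            best[i] = steps[i] >= steps[best[i - 1]] ? i : best[i - 1];
        }
        return best;
    }

    public static void main(String[] args) {
        int[] best = buildBest(LIMIT);   // O(N) work, done once
        Scanner scan = new Scanner(System.in);
        int t = scan.nextInt();
        StringBuilder out = new StringBuilder();
        while (t-- > 0) {
            out.append(best[scan.nextInt()]).append('\n'); // O(1) per test case
        }
        System.out.print(out);
    }
}
```

With this layout the per-query loop disappears entirely; each test case is a single array lookup.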

Be careful of overflow

When computing the Collatz sequence, it is easy to overflow a 32-bit integer. I would recommend using a long instead of an int. You might be overflowing and not knowing it, because an overflowed int typically wraps around to a negative number, and the if (temp <= 1) check then cuts the sequence short.
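To see whether overflow is a real risk for a given start value, one option is to track the largest value on its trajectory in a long and compare it with Integer.MAX_VALUE. This is only an illustrative sketch; `CollatzOverflow`, `peak` and `overflowsInt` are hypothetical names, and 113383 is used in `main` because it is the start value usually cited as overflowing an int (the code checks this rather than taking it on faith).

```java
class CollatzOverflow {
    /** Largest value reached on the trajectory of start, computed in a long. */
    static long peak(long start) {
        long n = start;
        long max = start;
        while (n != 1) {
            n = (n & 1) == 0 ? n >> 1 : 3 * n + 1;
            if (n > max) {
                max = n;
            }
        }
        return max;
    }

    /** True if an int-based implementation would overflow for this start value. */
    static boolean overflowsInt(long start) {
        return peak(start) > Integer.MAX_VALUE;
    }

    public static void main(String[] args) {
        long start = 113383;
        System.out.println("peak = " + peak(start)
                + ", overflows int: " + overflowsInt(start));
    }
}
```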

I guess you only posted this to get advice on performance, but I will nonetheless give you a short code review because your code is very sloppy (not the actual logic, but cleanliness / readability).

The whole while (num >= 1) block should be extracted into its own function, perhaps called solve, instead of just being some code within the main method.

Java naming convention would use countArray instead of countarray.

I would personally have kept the natural index as the key in countarray, meaning countarray[num] is the count for num instead of having to write countarray[num - 1] everywhere. Of course, countarray would then need one extra element (size N + 1 for a largest input of N) and countarray[0] would be meaningless.

Try to label all member variables that never get reassigned as final. So here countarray should be final. In most newer languages, variables are immutable (aka final) by default unless otherwise specified. It’s easier to reason about code when you know most things can’t change.

Java has the syntax 5_000_000. I actually had to count the zeros because I was not sure if it was 500,000 or 5,000,000.

You should have defined something like private static final int CACHE_SIZE = 5_000_000 and used that throughout your code. Right now 5000000 is hard-coded everywhere, which makes it awkward to modify, and in one of those places you could easily write 500000 by mistake.

I don’t much like that count is static (of course countarray should be static). But I guess you had no choice since you used a functional style instead of an object-oriented style, so you only have static methods. You can look at what others did when submitting this very same problem for review on this site (see “Related” on the right of this page). Functional style is probably better for this task, but it’s not quite the Java way. It would have been more appropriate in Kotlin, for example.
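Several of the points above can be combined in one small sketch: a named constant with digit separators, a final array in camelCase, natural indexing, a long parameter, and a return value instead of the static count. `CACHE_SIZE` comes from the suggestion above; `CollatzCache` and `steps` are hypothetical names, and this is just one way to arrange it.

```java
class CollatzCache {
    private static final int CACHE_SIZE = 5_000_000;
    // countArray[n] holds the step count for n; index 0 is unused.
    private static final int[] countArray = new int[CACHE_SIZE + 1];

    /** Number of Collatz steps needed to get from n down to 1. */
    static int steps(long n) {
        if (n == 1) {
            return 0;
        }
        if (n <= CACHE_SIZE && countArray[(int) n] != 0) {
            return countArray[(int) n]; // already memoized
        }
        long next = (n & 1) == 0 ? n / 2 : 3 * n + 1;
        int result = 1 + steps(next);
        if (n <= CACHE_SIZE) {
            countArray[(int) n] = result;
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(steps(27));
    }
}
```

Because `steps` returns its result, no shared mutable counter is needed, and changing the cache size means editing a single line.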


On a different note, related to some comments above, I did some searching and it does seem that the JVM can now sometimes inline recursive calls. I was glad to see that, since it can make for big performance improvements. Nevertheless, you should keep in mind that it probably won’t inline all recursive functions, and you might sometimes pay a performance penalty for using recursion.

Also, did you figure out how to get better performance? Many people posted reviews for the exact same problem (see “Related” on the right of this page) and they have pretty much the same solution as you and no one seemed to have performance problems.

EDIT: see JS1’s answer for the biggest efficiency gain. I’ll leave this answer here for some minor improvements in case you want to squeeze out more efficiency at the cost of readability.

  • Remove the recursion.
  • After an increase step (number was odd), you already know it’s even (since 3*odd is still odd and +1 makes it even). So you can skip some checks.
  • Use bitwise operators for slightly faster even check and to divide by 2.

The main algorithm could look something like this:

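for (int i = 2; i < maxStoredIndex; i++) {
    int count = 0;
    long n = i; // long, because 3*n + 1 can overflow an int
    while (n >= i) { // n not in previously calculated results
        if ((n & 1) == 0) { // n even; parentheses needed because == binds tighter than &
            n >>= 1; // n /= 2
            count++;
        } else {
            n = (3 * n + 1) >> 1; // odd step plus the even step that must follow it
            count += 2;
        }
    }
    // count += previously calculated result for n; since n < i we know this exists
    // store count as the result for i
}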

I didn’t check how much faster this would be, and you still need to fill in the parts like storing/fetching previous results. But this should give you the main idea on how to speed up filling in the cache.
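One possible way to fill in the storing and fetching parts is sketched here, under assumptions: `CollatzIterative`, `fillSteps` and `steps` are made-up names, the array is indexed directly by n, and the trajectory value is kept in a long to avoid overflow. No timing was measured for this version either.

```java
class CollatzIterative {
    /** Returns steps[], where steps[n] is the step count for n, for 1 <= n <= limit. */
    static int[] fillSteps(int limit) {
        int[] steps = new int[limit + 1]; // steps[1] = 0; index 0 is unused
        for (int i = 2; i <= limit; i++) {
            long n = i;
            int count = 0;
            while (n >= i) { // stop as soon as n drops to an already stored value
                if ((n & 1) == 0) {
                    n >>= 1;              // even: halve
                    count++;
                } else {
                    n = (3 * n + 1) >> 1; // odd: 3n + 1 is even, so halve right away
                    count += 2;
                }
            }
            steps[i] = count + steps[(int) n]; // n < i, so steps[n] is known
        }
        return steps;
    }

    public static void main(String[] args) {
        int[] steps = fillSteps(30);
        System.out.println("steps[27] = " + steps[27]);
    }
}
```

The combined odd step never skips a stored value, because 3n + 1 is larger than n and so cannot fall below i before the halving.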
