Iterative quicksort in Java

Posted on

Problem

I have this quicksort implementation that relies on heap memory instead of the actual call stack. It uses two integer stacks to keep track of which subranges are yet to be sorted. When the partition routine is called, it does its job and returns the index of the final position of the pivot element. The algorithm then looks at the two subranges separated by the pivot; each subrange that is longer than one element is pushed onto the stacks so that it gets sorted in a later iteration.

Code

IterativeQuicksort.java

package net.coderodde.util;

import java.util.Arrays;
import java.util.Objects;
import java.util.Random;

public class IterativeQuicksort {

    private static final class IntStack {

        private int[] storage;
        private int head;
        private int tail;
        private int moduloMask;
        private int size;

        IntStack() {
            storage = new int[8];
            moduloMask = 7;
        }

        void push(int num) {
            if (size == storage.length) {
                doubleCapacity();
            }

            storage[tail] = num;
            tail = (tail + 1) & moduloMask;
            size++;
        }

        int pop() {
            int ret = storage[head];
            head = (head + 1) & moduloMask;
            size--;
            return ret;
        }

        int size() {
            return size;
        }

        @Override
        public String toString() {
            StringBuffer sb = new StringBuffer("[");
            String separator = "";

            for (int i = 0; i < size; ++i) {
                sb.append(separator).append(storage[(head + i) & moduloMask]);
                separator = ", ";
            }

            return sb.append(']').toString();
        }

        private void doubleCapacity() {
            int[] newStorage = new int[storage.length << 1];

            for (int i = 0; i < size; ++i) {
                newStorage[i] = storage[(head + i) & moduloMask];
            }

            head = 0;
            tail = size;
            moduloMask = newStorage.length - 1;
            storage = newStorage;
        }
    }

    public static void sort(int[] array) {
        sort(array, 0, array.length);
    }

    public static void sort(int[] array, int fromIndex, int toIndex) {
        Objects.requireNonNull(array, "The input array is null.");
        rangeCheck(array.length, fromIndex, toIndex);

        IntStack startIndexStack = new IntStack();
        IntStack endIndexStack   = new IntStack();

        startIndexStack.push(fromIndex);
        endIndexStack.push(toIndex);

        while (startIndexStack.size() > 0) {
            int startIndex = startIndexStack.pop();
            int endIndex = endIndexStack.pop();

            int pivotIndex = partition(array, startIndex, endIndex);

            int leftChunkLength = pivotIndex - startIndex;
            int rightChunkLength = endIndex - pivotIndex - 1;

            if (leftChunkLength > 1) {
                startIndexStack.push(startIndex);
                endIndexStack.push(pivotIndex);
            }

            if (rightChunkLength > 1) {
                startIndexStack.push(pivotIndex + 1);
                endIndexStack.push(endIndex);
            }
        }
    }

    private static int partition(int[] array, int startIndex, int endIndex) {
        int pivot = array[startIndex];
        int i = startIndex - 1;

        for (int j = startIndex; j < endIndex; ++j) {
            if (array[j] <= pivot) {
                ++i;

                if (i != j) {
                    swap(array, i, j);
                }
            }
        }

        swap(array, startIndex, i);
        return i;
    }

    private static void swap(int[] array, int i, int j) {
        int tmp = array[i];
        array[i] = array[j];
        array[j] = tmp;
    }

    private static void rangeCheck(int arrayLength,
                                   int fromIndex, 
                                   int toIndex) {
        if (fromIndex > toIndex) {
            throw new IllegalArgumentException("fromIndex (" + fromIndex
                    + ") > toIndex (" + toIndex + ")");
        }

        if (fromIndex < 0) {
            throw new ArrayIndexOutOfBoundsException("fromIndex (" + fromIndex +
                    ") < 0");
        }

        if (toIndex > arrayLength) {
            throw new ArrayIndexOutOfBoundsException("toIndex (" + toIndex +
                    ") > arrayLength (" + arrayLength + ")");
        }
    }

    private IterativeQuicksort() {

    }

    private static final int LENGTH = 5_000_000;
    private static final int FROM_INDEX = 10;
    private static final int TO_INDEX = LENGTH - 10;
    private static final int MAXIMUM = 5_000_000;
    private static final int MINIMUM = -5_000_000;

    public static void main(String[] args) {
        long seed = System.currentTimeMillis();
        Random random = new Random(seed);
        int[] array = getRandomIntArray(LENGTH, MINIMUM, MAXIMUM, random);

        warmup(array.clone());
        benchmark(array.clone());
    }

    private static void warmup(int[] array) {
        perform(array, true);
    }

    private static void benchmark(int[] array) {
        perform(array, false);
    }

    private static void perform(int[] array, boolean warmup) {
        int[] array1 = array.clone();
        int[] array2 = array.clone();

        long startTime = System.currentTimeMillis();
        IterativeQuicksort.sort(array1, FROM_INDEX, TO_INDEX);
        long endTime = System.currentTimeMillis();

        if (!warmup) {
            System.out.println("IterativeQuicksort.sort in " + 
                    (endTime - startTime) + " ms.");
        }

        startTime = System.currentTimeMillis();
        Arrays.sort(array2, FROM_INDEX, TO_INDEX);
        endTime = System.currentTimeMillis();

        if (!warmup) {
            System.out.println("Arrays.sort in " + (endTime - startTime) + 
                    " ms.");
            System.out.println("Algorithms agree: " + Arrays.equals(array1,
                                                                    array2));
        }
    }

    private static int[] getRandomIntArray(int length, 
                                           int minimum,
                                           int maximum, 
                                           Random random) {
        int[] array = new int[length];

        for (int i = 0; i < length; ++i) {
            array[i] = minimum + random.nextInt(maximum - minimum + 1);
        }

        return array;
    }
}

Sample output

IterativeQuicksort.sort in 1487 ms.
Arrays.sort in 1380 ms.
Algorithms agree: true

Critique request

Please point out anything that needs improvement.

Solution

Can use fixed sized stack

Just a quick observation: if you always push the larger range on the stack first and the smaller range second, the smaller range is always processed next, so the number of pending ranges grows at most logarithmically with the array length. Because Java arrays cannot hold $2^{31}$ or more elements, you will never need more than 32 stack entries (ranges). So you could remove your IntStack class and instead use a fixed-size int array.

Sample rewrite

Here is a rewrite of your function using a fixed stack. I combined both the startIndex and endIndex stacks into a single stack where I push and pop two entries at a time:

/**
 * Sorts the range {@code array[fromIndex .. toIndex)} into ascending order
 * using an iterative quicksort with a fixed-size index stack.
 *
 * @param array     the array holding the range to sort.
 * @param fromIndex the first index of the range, inclusive.
 * @param toIndex   the last index of the range, exclusive.
 * @throws NullPointerException if {@code array} is {@code null}.
 */
public static void sort(int[] array, int fromIndex, int toIndex) {
    // Validate before allocating anything.
    Objects.requireNonNull(array, "The input array is null.");
    rangeCheck(array.length, fromIndex, toIndex);

    // A range with fewer than two elements is already sorted. Partitioning an
    // empty range would touch elements outside of it, so return early.
    if (toIndex - fromIndex < 2) {
        return;
    }

    // Each pending range occupies two slots (start, end), so 64 slots hold 32
    // ranges. Because the smaller chunk is always pushed last and therefore
    // processed first, the number of pending ranges stays logarithmic in the
    // range length.
    int[] indexStack = new int[64];
    int stackIndex = 0;

    indexStack[stackIndex++] = fromIndex;
    indexStack[stackIndex++] = toIndex;

    while (stackIndex > 0) {
        // Pop in the reverse order of pushing: end index first, then start.
        int endIndex   = indexStack[--stackIndex];
        int startIndex = indexStack[--stackIndex];

        int pivotIndex = partition(array, startIndex, endIndex);

        int leftChunkLength  = pivotIndex - startIndex;
        int rightChunkLength = endIndex - pivotIndex - 1;

        // Always push the larger chunk first, followed by the smaller chunk.
        if (leftChunkLength > 1 && leftChunkLength > rightChunkLength) {
            indexStack[stackIndex++] = startIndex;
            indexStack[stackIndex++] = pivotIndex;
        }

        if (rightChunkLength > 1) {
            indexStack[stackIndex++] = pivotIndex + 1;
            indexStack[stackIndex++] = endIndex;
        }

        if (leftChunkLength > 1 && leftChunkLength <= rightChunkLength) {
            indexStack[stackIndex++] = startIndex;
            indexStack[stackIndex++] = pivotIndex;
        }
    }
}

This version was about 10% faster than the original version, at least on my computer:

Original:

IterativeQuicksort.sort in 551 ms.
Arrays.sort in 406 ms.
Algorithms agree: true

Rewrite:

IterativeQuicksort.sort in 496 ms.
Arrays.sort in 405 ms.
Algorithms agree: true

Leave a Reply

Your email address will not be published. Required fields are marked *