Problem
A friend challenged me a while back to write the fastest possible program to compute A(3, 16), where A is the Ackermann function.
He wrote it in Java, and it took ~4.4 seconds. The (very ugly) Haskell program below, which I wrote, took ~1.7 seconds:
{-# LANGUAGE BangPatterns #-}
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

-- Main method:
main :: IO ()
main = print $ fst $ ack (I 3 16) Map.empty

-- Definitions:
data I = I {-# UNPACK #-} !Int {-# UNPACK #-} !Int
  deriving (Eq, Ord)

ack :: I -> Map I Int -> (Int, Map I Int)
ack (I 0 n) r = (n + 1, r)
ack i@(I m 0) r1 = maybe bak fun (Map.lookup i r1)
  where
    (!val, !r2) = ack (I (m-1) 1) r1
    bak = (val, Map.insert i val r2)
    fun v = (v, r1)
ack i@(I m n) r1 = maybe bak2 fun2 (Map.lookup call2 r2)
  where
    call1 = (I m (n - 1))
    (!val1, !re1) = ack call1 r1
    bak1 = (val1, Map.insert call1 val1 re1)
    fun1 v = (v, r1)
    (!v1, !r2) = maybe bak1 fun1 (Map.lookup call1 r1)
    call2 = (I (m - 1) v1)
    (!val2, !re2) = ack call2 r2
    bak2 = (val2, Map.insert call2 val2 re2)
    fun2 v = (v, r2)
I use Ints specifically because A(3, 16) is 524285, which is well within Int's range.
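For reference, the Ackermann function is defined by the recurrence below. For m ≤ 3 it has simple closed forms, which is where 524285 comes from; the Haskell 2010 Report guarantees that Int covers at least -2^29 to 2^29 - 1, and 524285 is below 2^29.

```latex
\begin{aligned}
A(0, n) &= n + 1, \qquad A(m, 0) = A(m - 1, 1), \qquad A(m, n) = A(m - 1, A(m, n - 1)) \\
A(1, n) &= n + 2, \qquad A(2, n) = 2n + 3, \qquad A(3, n) = 2^{n+3} - 3 \\
A(3, 16) &= 2^{19} - 3 = 524288 - 3 = 524285
\end{aligned}
```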
How can I make this function more memory-efficient, and faster for larger Ackermann arguments?
Also, the syntax is disgusting. How can I make this more readable (maybe using monads?) without having horrible performance consequences?
Solution
Using Data.MemoCombinators
Have a look at the Data.MemoCombinators module (from the data-memocombinators package). It offers combinators for memoizing functions, which is essentially what you are doing by hand with the Map.
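To make "combinators for memoizing functions" concrete, here is a deliberately naive hand-rolled one. memoNat is a hypothetical name, not part of Data.MemoCombinators, and the library's combinators use a more efficient structure than a list; the sketch only shows the core idea: results are cached in a lazily built data structure, and recursive calls are routed through the memoized function.

```haskell
module Main (main) where

-- Hypothetical, naive memo combinator for non-negative Ints: results live in a
-- lazy list, so each is computed at most once. Lookup via (!!) is O(n).
memoNat :: (Int -> r) -> Int -> r
memoNat f = (table !!)
  where
    table = map f [0 ..]

-- fib is the memoized function; fib' is the open-recursive helper, and its
-- recursive calls go through fib so they hit the shared table.
fib :: Int -> Integer
fib = memoNat fib'
  where
    fib' 0 = 0
    fib' 1 = 1
    fib' n = fib (n - 1) + fib (n - 2)

main :: IO ()
main = print (fib 50)
```

Because fib is a top-level constant, the table inside it is built once and shared by every call.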
Here is the example from the documentation showing how to use it to create a memoized Fibonacci function:
import qualified Data.MemoCombinators as Memo

fib = Memo.integral fib'
  where
    fib' 0 = 0
    fib' 1 = 1
    fib' x = fib (x-1) + fib (x-2)  -- these recursive calls go to fib, not fib'
Things to note:
- fib is the memoized version of fib', which is just a helper function
- fib' calls fib for any recursive calls
The other thing that helps is to keep in mind that the type Memo a represents a combinator that can memoize a function whose argument type is a.
So:
- integral is a combinator that can memoize functions taking an Int
- pair integral integral is a combinator that can memoize functions taking an (Int, Int)
And thus, to memoize the Ackermann function:
import Data.MemoCombinators (integral, pair)

ack :: (Int, Int) -> Int
ack = pair integral integral ack'
  where ack' (0,n) = n+1
        ack' (m,0) = ack (m-1,1)
        ack' (m,n) = ack (m-1, ack (m, n-1))
Again, note how ack is defined as the memoized version of ack', and ack' calls ack for the recursive cases.
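If you would rather keep an explicit Map as in the question, the "maybe using monads?" idea also works: a State monad can carry the memo table, so the code reads like the plain recurrence. This is a sketch assuming the mtl package's Control.Monad.State.Strict; Table and memoized are hypothetical names, and its speed has not been measured against the original.

```haskell
module Main (main) where

import Control.Monad.State.Strict (State, evalState, gets, modify')
import qualified Data.Map.Strict as Map

-- Hypothetical memo table: (m, n) -> A(m, n).
type Table = Map.Map (Int, Int) Int

-- Hypothetical helper: return the cached result for key if present;
-- otherwise run compute, store its result, and return it.
memoized :: (Int, Int) -> State Table Int -> State Table Int
memoized key compute = do
  cached <- gets (Map.lookup key)
  case cached of
    Just v  -> pure v
    Nothing -> do
      v <- compute
      modify' (Map.insert key v)
      pure v

-- The Ackermann recurrence; State threads the table instead of tuples.
ack :: Int -> Int -> State Table Int
ack 0 n = pure (n + 1)
ack m 0 = memoized (m, 0) (ack (m - 1) 1)
ack m n = memoized (m, n) $ do
  inner <- ack m (n - 1)
  ack (m - 1) inner

main :: IO ()
main = print (evalState (ack 3 16) Map.empty)
```

The question's unpacked I type could replace the (Int, Int) key if the tuple turns out to be a bottleneck.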
Using Array memoization
Faster results can be obtained by using arrays to memoize the functions ack(1,.), ack(2,.) and ack(3,.):
import Data.Array

-- memoize the function f using arrays
arrayMemo :: Ix i => (i, i) -> (Array i e -> i -> e) -> i -> e
arrayMemo bnds f = g
  where g i = arr ! i
        arr = array bnds [ (i, f arr i) | i <- range bnds ]

a0, a1, a2, a3 :: Int -> Int
a0 n = n+1

a1 = arrayMemo (0,530000) f
  where f arr 0 = a0 1
        f arr n = a0 (arr ! (n-1))

a2 = arrayMemo (0,270000) f
  where f arr 0 = a1 1
        f arr n = a1 (arr ! (n-1))

a3 = arrayMemo (0,16) f
  where f arr 0 = a2 1
        f arr n = a2 (arr ! (n-1))

main :: IO ()
main = print (a3 16)
The only drawback is that explicit bounds have to be determined for each level's function. This approach computes a3 16 in about half a second.
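One way to soften that drawback, sketched under the assumption that only A(3, n) is needed, is to derive each table's bound from the closed forms A(2, k) = 2k + 3 and A(3, k) = 2^(k+3) - 3. For n = 16 this gives bounds of 262141 for the ack(2,.) table and 524283 for the ack(1,.) table, just under the hand-picked 270000 and 530000. ack3, b1 and b2 are hypothetical names, and the sketch assumes 0 ≤ n and does not guard against Int overflow.

```haskell
module Main (main) where

import Data.Array (Array, Ix, array, range, (!))

-- Memoize f over the index range bnds using a lazy, self-referential boxed array.
arrayMemo :: Ix i => (i, i) -> (Array i e -> i -> e) -> i -> e
arrayMemo bnds f = (arr !)
  where
    arr = array bnds [ (i, f arr i) | i <- range bnds ]

-- Hypothetical helper: A(3, target) for target >= 0, with each table sized
-- from the closed forms instead of hand-picked bounds.
ack3 :: Int -> Int
ack3 target = a3 target
  where
    b2 = 2 ^ (target + 2) - 3  -- largest argument a2 receives: A(3, target - 1)
    b1 = 2 * b2 + 1            -- largest argument a1 receives: A(2, b2 - 1)

    a0 n = n + 1

    a1 = arrayMemo (0, b1) step
      where step _ 0 = a0 1
            step arr n = a0 (arr ! (n - 1))

    a2 = arrayMemo (0, b2) step
      where step _ 0 = a1 1
            step arr n = a1 (arr ! (n - 1))

    a3 = arrayMemo (0, target) step
      where step _ 0 = a2 1
            step arr n = a2 (arr ! (n - 1))

main :: IO ()
main = print (ack3 16)
```

The tables are now local to ack3, so they are rebuilt on each call; that fits a one-shot computation like A(3, 16) but not repeated queries.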