Project Euler 92: Square digit chains


Problem

A number chain is created by continuously adding the square of the
digits in a number to form a new number until it has been seen before.

For example,

  • 44 → 32 → 13 → 10 → 1 → 1
  • 85 → 89 → 145 → 42 → 20 → 4 → 16 → 37 → 58 → 89

Therefore any chain that arrives at 1 or 89 will become stuck in an
endless loop. What is most amazing is that EVERY starting number will
eventually arrive at 1 or 89.

How many starting numbers below ten million will arrive at 89?
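To make the chain rule concrete, here is a small Python sketch that follows a chain from a starting number until it reaches 1 or 89. The helper names `digit_square_sum` and `chain` are illustrative and are not part of the original post.

```python
def digit_square_sum(n):
    # Sum of the squares of the decimal digits of n, e.g. 44 -> 16 + 16 = 32
    total = 0
    while n:
        n, d = divmod(n, 10)
        total += d * d
    return total


def chain(start):
    # Follow the chain from start (start >= 1) until it first hits 1 or 89.
    # It stops there instead of printing the repeating loop.
    seq = [start]
    while seq[-1] not in (1, 89):
        seq.append(digit_square_sum(seq[-1]))
    return seq


print(chain(44))   # [44, 32, 13, 10, 1]
print(chain(145))  # [145, 42, 20, 4, 16, 37, 58, 89]
```

Starting at 0 would loop forever (0 maps to 0), which is why the question only asks about positive starting numbers.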

To tackle this problem I used the following procedure.

  • set array[1] to 1
  • set array[89] to 89
  • loop through all starting numbers below 10,000,000
  • for each starting number, repeatedly compute the sum of the squared digits until a previously calculated number is found
  • mark every number in the chain just created as ending in either 89 or 1
  • count the starting numbers that now have 89 as their array value

However, even with this relatively clever strategy, my code still takes half a minute to run. Is there a flaw in my logic, or are there other obvious improvements? I also suspect this strategy uses a large chunk of memory.

def square_digits(num):
    # Squares the digits of a number, eg 44=4^2+4^2=32
    total = 0
    while num:
        total += (num % 10) ** 2
        num //= 10
    return total


def single_chain(num, square_dict):
    # Evaluates a single chain until 1 or 89
    temp_nums = []
    while True:
        if num in [1, 89]:
            break
        try:
            # If we hit a previously calculated value, break
            num = square_dict[num]
            break
        except: 
            temp_nums.append(num)
            num = square_digits(num)
    for i in temp_nums:
        # Backtrack through the chain saving each number
        square_dict[i] = num
    return num == 1, square_dict


def pre_calculation(limit):
    # Precalculates the first values
    square_dict = dict()
    for i in range(1,limit+1):
        num = i
        while num not in [1,89]:
            num = square_digits(num)
        if num == 1:
            square_dict[i] = 1
        else:
            square_dict[i] = 89
    return square_dict


def square_chains(limit, square_dict):
    # Finds the number of chains ending in 1 and 89
    count_1 = 0
    count_89 = 0
    for i in range(1, limit):
        boolean, square_dict = single_chain(i, square_dict)
        if boolean:
            count_1 += 1
        else:
            count_89 += 1
    print "Chains ending in 1: ", count_1
    print "Chains ending in 89: ", count_89


if __name__ == '__main__':
    square_dict = pre_calculation(9*9*7)
    square_chains(10**7,square_dict)
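The `9*9*7` passed to `pre_calculation` is the largest digit-square sum any number below ten million can have: seven digits, each contributing at most 9² = 81, so 7 × 81 = 567. A consequence worth knowing is that one step maps every starting number into 1..567, so a table of 567 endpoints can classify any number with a single lookup and very little memory. The sketch below shows that one-step lookup; it is a swapped-in technique rather than the asker's code, and the names `digit_square_sum`, `chain_end`, `ends` and `ends_at_89` are hypothetical.

```python
def digit_square_sum(n):
    # Sum of the squares of the decimal digits of n
    total = 0
    while n:
        n, d = divmod(n, 10)
        total += d * d
    return total


def chain_end(n):
    # Follow the chain from n (n >= 1) until it reaches 1 or 89
    while n not in (1, 89):
        n = digit_square_sum(n)
    return n


BOUND = 9 * 9 * 7  # 567: seven digits, each contributing at most 9**2 = 81

# ends[s] is the loop (1 or 89) reached from s; index 0 is an unused placeholder
ends = [0] + [chain_end(s) for s in range(1, BOUND + 1)]


def ends_at_89(n):
    # For 1 <= n < 10**7 a single step lands in 1..567, so one lookup suffices
    return ends[digit_square_sum(n)] == 89


print(digit_square_sum(9_999_999))  # 567
print(ends_at_89(4_666_777))        # True
print(ends_at_89(44))               # False
```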

Solution

I already posted the incremental improvement answer. Here’s the home run answer.

Consider the number 4,666,777. This number happens to chain into 89. That takes some amount of work to figure out. But eventually we get there. What does this tell us? Since we’re only interested in the sum of the squares of the digits, the actual ordering of the digits is irrelevant. That is… once we know that 4,666,777 is valid, we also know that 6,466,777 is valid, and 7,664,776 is valid, and … All 140 unique permutations of the digits 4666777 are things we want to count. The key is: once we’re done with 4666777, we do not even need to consider the other 139!
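The figure of 140 is the multinomial coefficient for one 4, three 6s and three 7s: 7! / (1! · 3! · 3!) = 5040 / 36 = 140. A quick Python check of that arithmetic, comparing the formula with a brute-force count of distinct orderings and confirming they all share the same digit-square sum:

```python
from itertools import permutations
from math import factorial

# One 4, three 6s, three 7s: 7! / (1! * 3! * 3!)
formula = factorial(7) // (factorial(1) * factorial(3) * factorial(3))

# Brute force: distinct orderings of the seven digits
orderings = set(permutations("4666777"))
print(formula, len(orderings))  # 140 140

# Every ordering has the same digit-square sum, so they share one chain
sums = {sum(int(c) ** 2 for c in p) for p in orderings}
print(sums)  # {271}
```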

If we treat every number below 10,000,000 as a seven-digit string with leading zeros (0000000 to 9999999), there are only 11,440 unique digit combinations. Any solution that checks every number is therefore doing about 875x (10,000,000 / 11,440) more work than necessary. We can use itertools.combinations_with_replacement to generate the unique digit combinations, and then use itertools.groupby to help determine how many distinct numbers each combination represents.
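The 11,440 figure is the number of multisets of size 7 drawn from 10 digits, C(10 + 7 − 1, 7) = C(16, 7). The sketch below checks that count and confirms that summing the number of arrangements of each combination recovers all 10,000,000 seven-digit strings, so the grouping neither misses nor double-counts anything. The helper name `arrangements` is illustrative; `math.comb` needs Python 3.8 or later.

```python
from itertools import combinations_with_replacement, groupby
from math import comb, factorial

combos = list(combinations_with_replacement(range(10), 7))
print(len(combos))   # 11440
print(comb(16, 7))   # 11440, stars and bars


def arrangements(combo):
    # combo is sorted, so groupby yields each run of equal digits exactly once
    count = factorial(len(combo))
    for _, run in groupby(combo):
        count //= factorial(len(list(run)))
    return count


# Every string 0000000..9999999 is counted exactly once
print(sum(arrangements(c) for c in combos))  # 10000000
```

The all-zero combination stands for 0, which is not a valid starting number; the code below skips it with its `cur > 0` check.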

This still uses the memoized my_square_chain from the Memoization section:

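import itertools
from math import factorial as fact


def euler89():
    # Counts starting numbers below 10**7 whose chain ends at 89
    count_89 = 0
    fact7 = fact(7)
    digits = range(10)

    # Each sorted 7-tuple of digits stands for every number that rearranges it
    for num in itertools.combinations_with_replacement(digits, 7):
        cur = sum(d**2 for d in num)
        # cur == 0 only for 0000000, which is not a valid starting number
        if cur > 0 and my_square_chain(cur) == 89:
            # 7! / (k1! * k2! * ...), one factor per run of equal digits
            count = fact7
            for _, g in itertools.groupby(num):
                count //= fact(len(list(g)))
            count_89 += count
    return count_89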

This runs in 0.120s on my box: a 265x improvement over my incremental changes and 381x over the original solution.
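For anyone who wants to run the combinatorial approach on its own, here is a self-contained Python 3 sketch. It is an adaptation, not the answer's exact code: `functools.lru_cache` stands in for the answer's `@memoize`, integer division `//=` replaces `/=` (which yields floats in Python 3), and the names `square_digit_sum`, `chain_end` and `count_89` are my own.

```python
from functools import lru_cache
from itertools import combinations_with_replacement, groupby
from math import factorial


def square_digit_sum(n):
    # Sum of the squares of the decimal digits of n
    total = 0
    while n:
        n, d = divmod(n, 10)
        total += d * d
    return total


@lru_cache(maxsize=None)
def chain_end(n):
    # Memoized: returns 1 or 89, the loop that n (n >= 1) eventually falls into
    if n == 1 or n == 89:
        return n
    return chain_end(square_digit_sum(n))


def count_89(num_digits=7):
    # Counts starting numbers 1 .. 10**num_digits - 1 whose chain reaches 89
    total = 0
    fact_n = factorial(num_digits)
    for combo in combinations_with_replacement(range(10), num_digits):
        s = sum(d * d for d in combo)
        if s > 0 and chain_end(s) == 89:
            # Number of distinct orderings of this digit multiset
            count = fact_n
            for _, run in groupby(combo):
                count //= factorial(len(list(run)))
            total += count
    return total


if __name__ == "__main__":
    print(count_89())
```

Passing a smaller `num_digits` makes it easy to cross-check the result against a brute-force loop over every number.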

This won’t affect performance, but your choice of return value for single_chain is odd. The dictionary is mutated in place, and Python passes a reference to the same dict object, so the caller already sees every update; there is no need to return square_dict. Just return num. Then, on the calling side:

if single_chain(i, square_dict) == 1:
    count_1 += 1
else:
    count_89 += 1

reads a bit better. Your pre_calculation doesn’t do anything useful either; you can drop it completely.
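Putting those two suggestions together, single_chain can return only the chain's end point and square_chains can start from an empty dict. This is a sketch of how the refactor might look, not code from the answer; it also borrows the non-throwing `dict.get` lookup discussed in the next section, and returning the counts instead of printing them is my own choice.

```python
def square_digits(num):
    # Squares the digits of a number and sums them, e.g. 44 -> 4**2 + 4**2 = 32
    total = 0
    while num:
        total += (num % 10) ** 2
        num //= 10
    return total


def single_chain(num, square_dict):
    # Returns the end of the chain (1 or 89); square_dict is updated in place
    temp_nums = []
    while num not in (1, 89):
        cached = square_dict.get(num)
        if cached is not None:
            num = cached  # cached values are already end points
            break
        temp_nums.append(num)
        num = square_digits(num)
    for i in temp_nums:
        # Backtrack through the chain, recording its end point
        square_dict[i] = num
    return num


def square_chains(limit):
    # No pre_calculation needed: the cache fills up as chains are explored
    square_dict = {}
    count_1 = count_89 = 0
    for i in range(1, limit):
        if single_chain(i, square_dict) == 1:
            count_1 += 1
        else:
            count_89 += 1
    return count_1, count_89
```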

Exceptions exceptions exceptions

Exceptions are expensive, and your code raises and catches one on every cache miss. Most of your time is spent here:

try:
    # If we hit a previously calculated value, break
    num = square_dict[num]
    break
except: 
    temp_nums.append(num)
    num = square_digits(num)

But if we reorganize this to a non-throwing version:

store = square_dict.get(num, None)
if store is not None:
    num = store
    break
else:
    temp_nums.append(num)
    num = square_digits(num)

Runtime drops from 45.7s on my box to 34.7s.
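To measure the difference on your own interpreter, a micro-benchmark of a cache miss is enough. The sketch below times both lookup styles with `timeit`; the function names are illustrative and the timings will vary by machine and Python version, so no numbers are claimed here. It also narrows the handler to `KeyError`: a bare `except:` additionally swallows exceptions such as `KeyboardInterrupt`, which hides real problems.

```python
import timeit

cache = {1: 1, 89: 89}


def lookup_try(num):
    # Exception-based lookup, narrowed to KeyError
    try:
        return cache[num]
    except KeyError:
        return None


def lookup_get(num):
    # Non-throwing lookup: dict.get returns None on a miss
    return cache.get(num)


# Time a cache miss, the case that dominates the original loop
for fn in (lookup_try, lookup_get):
    t = timeit.timeit(lambda: fn(42), number=200_000)
    print(f"{fn.__name__}: {t:.3f}s")
```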

Memoization

This problem also lends itself well to memoization, which makes the logic much simpler, since you can just write:

@memoize
def my_square_chain(i):
    if i == 1 or i == 89: 
        return i

    return my_square_chain(square_digits(i))

It’s shorter and easier to reason about, with the added benefit that it’s also a little bit faster (31.8s).
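The answer doesn’t show `memoize` itself; presumably it is a small caching decorator. Below is one hypothetical way to write it so the snippet above runs, together with the original summing `square_digits`. The standard library’s `functools.lru_cache(maxsize=None)` would be a drop-in alternative.

```python
import functools


def memoize(func):
    # Hypothetical stand-in for the answer's @memoize: cache one result per argument
    cache = {}

    @functools.wraps(func)
    def wrapper(n):
        if n not in cache:
            cache[n] = func(n)
        return cache[n]

    return wrapper


def square_digits(num):
    # Sum of the squares of the digits, as in the question's version
    total = 0
    while num:
        total += (num % 10) ** 2
        num //= 10
    return total


@memoize
def my_square_chain(i):
    # i must be >= 1; 0 maps to itself and would recurse forever
    if i == 1 or i == 89:
        return i
    return my_square_chain(square_digits(i))


print(my_square_chain(44))  # 1
print(my_square_chain(85))  # 89
```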

square_digits

Your square_digits both squares and sums the digits, which makes it longer and less reusable. Instead, just yield the squares:

def square_digits(num):
    while num:
        yield (num % 10) ** 2
        num //= 10

Example usage:

>>> square_digits(146)
<generator object square_digits at 0x7f6a0b5930a0>
>>> list(_)  # `_` is the previous value
[36, 16, 1]

You can make the function even more versatile by yielding the digits themselves:

def digits(n):
    while n:
        yield n % 10
        n //= 10

Now you can reuse it for any problem that involves extracting digits.

Note that using these generators requires a slight change in the calling code:

  • with the generator square_digits: square_digits(num) -> sum(square_digits(num))
  • with digits: square_digits(num) -> sum(i**2 for i in digits(num))
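A short check that both calling forms give the same sum, using the two generators exactly as written above. Both yield the least significant digit first, which is why 146 comes out reversed.

```python
def square_digits(num):
    # Yields the square of each digit, least significant first
    while num:
        yield (num % 10) ** 2
        num //= 10


def digits(n):
    # Yields the decimal digits of n, least significant first (nothing for 0)
    while n:
        yield n % 10
        n //= 10


n = 44
print(sum(square_digits(n)))         # 32
print(sum(i**2 for i in digits(n)))  # 32
print(list(digits(146)))             # [6, 4, 1]
```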
