Problem

A number chain is created by continuously adding the square of the digits in a number to form a new number until it has been seen before. For example,

- 44 → 32 → 13 → 10 → 1 → 1
- 85 → 89 → 145 → 42 → 20 → 4 → 16 → 37 → 58 → 89

Therefore any chain that arrives at 1 or 89 will become stuck in an endless loop. What is most amazing is that EVERY starting number will eventually arrive at 1 or 89. How many starting numbers below ten million will arrive at 89?

To tackle this problem I used the following procedure.

- set array[1] to 1
- set array[89] to 89
- loop through all starting numbers from 2 to 10,000,000
- for each starting number, calculate sum of squares until a previously calculated number is found
- mark off all numbers in the chain you have just created as either 89 or 1
- count the number of starting numbers that now have 89 as their array value

However, even with this relatively clever strategy, my code still takes half a minute to compute. Is there something wrong in my logic, or are there any other obvious improvements? I also think this strategy uses a large chunk of memory.

```
def square_digits(num):
    # Sums the squares of the digits of a number, e.g. 44 -> 4^2 + 4^2 = 32
    total = 0
    while num:
        total += (num % 10) ** 2
        num //= 10
    return total

def single_chain(num, square_dict):
    # Evaluates a single chain until 1 or 89
    temp_nums = []
    while True:
        if num in [1, 89]:
            break
        try:
            # If we hit a previously calculated value, break
            num = square_dict[num]
            break
        except:
            temp_nums.append(num)
            num = square_digits(num)
    for i in temp_nums:
        # Backtrack through the chain saving each number
        square_dict[i] = num
    return num == 1, square_dict

def pre_calculation(limit):
    # Precalculates the first values
    square_dict = dict()
    for i in range(1, limit + 1):
        num = i
        while num not in [1, 89]:
            num = square_digits(num)
        if num == 1:
            square_dict[i] = 1
        else:
            square_dict[i] = 89
    return square_dict

def square_chains(limit, square_dict):
    # Finds the number of chains ending in 1 and 89
    count_1 = 0
    count_89 = 0
    for i in range(1, limit):
        boolean, square_dict = single_chain(i, square_dict)
        if boolean:
            count_1 += 1
        else:
            count_89 += 1
    print("Chains ending in 1:", count_1)
    print("Chains ending in 89:", count_89)

if __name__ == '__main__':
    square_dict = pre_calculation(9*9*7)
    square_chains(10**7, square_dict)
```

Solution

I already posted the incremental improvement answer. Here’s the home run answer.

Consider the number 4,666,777. This number happens to chain into 89. That takes some amount of work to figure out. But eventually we get there. What does this tell us? Since we’re only interested in the sum of the squares of the digits, the actual ordering of the digits is irrelevant. That is… once we know that 4,666,777 is valid, we also know that 6,466,777 is valid, and 7,664,776 is valid, and … All 140 unique permutations of the digits 4666777 are things we want to count. The key is: once we’re done with 4666777, we do not even need to consider the other 139!
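The figure of 140 can be checked directly. Here is a minimal sketch, assuming Python 3; the helper name `distinct_orderings` is mine and not part of the original code. It computes the multinomial coefficient 7! / (1! · 3! · 3!) and compares it against a brute-force count of distinct permutations.

```python
from collections import Counter
from itertools import permutations
from math import factorial

def distinct_orderings(digits):
    """Number of distinct arrangements of a multiset of digits (hypothetical helper)."""
    count = factorial(len(digits))
    for multiplicity in Counter(digits).values():
        # Divide out the orderings of identical digits.
        count //= factorial(multiplicity)
    return count

digits = "4666777"
formula = distinct_orderings(digits)
# Brute force: generate all 5040 orderings and keep the distinct ones.
brute_force = len(set(permutations(digits)))
print(formula, brute_force)
```

Both values should agree, which is what lets us count a whole family of numbers from a single representative.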

There are only 11,440 unique digit combinations from 1 to 10,000,000. Any solution checking all ten million numbers is thus doing ~900x more work than necessary. We can use `itertools.combinations_with_replacement` to get the unique digit combinations, and then use `itertools.groupby` to help determine how many distinct orderings each combination stands for.
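As a quick sanity check on the 11,440 figure, the sketch below (assuming Python 3.8+ for `math.comb`) counts the 7-digit multisets both by enumeration and with the closed form C(10 + 7 − 1, 7). Note that this count includes the all-zeros combination, which corresponds to the number 0 and has to be skipped when counting chains.

```python
from itertools import combinations_with_replacement
from math import comb

# Enumerate every multiset of 7 digits drawn from 0-9 (leading zeros allowed,
# so shorter numbers are covered as well).
multisets = sum(1 for _ in combinations_with_replacement(range(10), 7))

# Closed form for multisets of size 7 over 10 symbols.
closed_form = comb(10 + 7 - 1, 7)

# Rough ratio of brute-force work to multiset work.
ratio = 10**7 / multisets
print(multisets, closed_form, round(ratio))
```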

Still with the memoized `my_square_chain`:

```
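import itertools
from math import factorial as fact  # assumption: `fact` is the factorial function

def euler89():
    count_89 = 0
    fact7 = fact(7)
    digits = range(10)
    for num in itertools.combinations_with_replacement(digits, 7):
        cur = sum(d**2 for d in num)
        # cur == 0 is the all-zeros combination, i.e. the number 0
        if cur > 0 and my_square_chain(cur) == 89:
            # 7! divided by the factorial of each digit's multiplicity
            count = fact7
            for _, g in itertools.groupby(num):
                count //= fact(len(list(g)))
            count_89 += count
    return count_89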
```

This runs in **0.120s** on my box. A performance improvement of 265x from my incremental changes, and 381x from the original solution.
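For readers who want to run the idea end to end, here is a self-contained sketch under a few assumptions of mine: Python 3, `functools.lru_cache` standing in for the `@memoize` decorator, and `math.factorial` standing in for `fact`. The names `digit_square_sum`, `chain_end` and `count_chains_to_89` are illustrative, not from the original posts.

```python
from functools import lru_cache
from itertools import combinations_with_replacement, groupby
from math import factorial

def digit_square_sum(n):
    """Sum of the squares of the decimal digits of n."""
    total = 0
    while n:
        n, d = divmod(n, 10)
        total += d * d
    return total

@lru_cache(maxsize=None)
def chain_end(n):
    """Return 1 or 89, whichever the chain starting at n reaches."""
    if n == 1 or n == 89:
        return n
    return chain_end(digit_square_sum(n))

def count_chains_to_89(num_digits=7):
    """Count starting numbers with at most num_digits digits that reach 89."""
    total = 0
    for combo in combinations_with_replacement(range(10), num_digits):
        start = sum(d * d for d in combo)
        if start == 0:
            continue  # all zeros is the number 0, not a valid start
        if chain_end(start) == 89:
            # Distinct orderings of this digit multiset.
            orderings = factorial(num_digits)
            for _, group in groupby(combo):
                orderings //= factorial(len(list(group)))
            total += orderings
    return total

if __name__ == "__main__":
    print(count_chains_to_89())
```

Because `combinations_with_replacement` over `range(10)` yields sorted tuples, equal digits are adjacent and `groupby` can read off each digit's multiplicity without sorting first.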

This won’t affect performance, but your return choice for `single_chain` is odd. Since everything in Python is a reference, you don’t need to return `square_dict`. It’s unnecessary; you could just return `num`. Then on the call side:

```
if single_chain(i, square_dict) == 1:
    count_1 += 1
else:
    count_89 += 1
```

reads a bit better. Your `pre_calculation` doesn’t actually do anything useful either; you can drop it completely.
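To illustrate why returning the dictionary is redundant, here is a tiny sketch (the function name `record` is hypothetical): a dict passed to a function is the same object the caller holds, so in-place updates are visible without returning anything.

```python
def record(cache, key, value):
    # Mutates the caller's dict in place; nothing is returned.
    cache[key] = value

square_dict = {}
record(square_dict, 44, 1)
record(square_dict, 85, 89)
print(square_dict)
```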

**Exceptions exceptions exceptions**

Exceptions are expensive. Most of your time is spent here:

```
try:
    # If we hit a previously calculated value, break
    num = square_dict[num]
    break
except:
    temp_nums.append(num)
    num = square_digits(num)
```

But if we reorganize this to a non-throwing version:

```
store = square_dict.get(num, None)
if store is not None:
    num = store
    break
else:
    temp_nums.append(num)
    num = square_digits(num)
```

Runtime drops from 45.7s on my box to 34.7s.
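If you want to measure the difference on your own machine, a rough micro-benchmark along these lines may help. It is only a sketch: the helper names are mine, it times nothing but cache misses (the case where an exception is actually raised), and the numbers will vary with the interpreter version and the hit ratio. Using `None` as the "missing" marker is safe here only because stored values are always 1 or 89.

```python
import timeit

def lookup_with_try(cache, key):
    """Dict lookup that pays for a raised KeyError on every miss."""
    try:
        return cache[key]
    except KeyError:
        return None

def lookup_with_get(cache, key):
    """Dict lookup that returns None on a miss without raising."""
    return cache.get(key)

def time_misses(n=100_000):
    """Time n lookups against an empty dict, so every lookup is a miss."""
    cache = {}
    keys = range(n)
    t_try = timeit.timeit(lambda: [lookup_with_try(cache, k) for k in keys], number=1)
    t_get = timeit.timeit(lambda: [lookup_with_get(cache, k) for k in keys], number=1)
    return t_try, t_get

if __name__ == "__main__":
    t_try, t_get = time_misses()
    print(f"try/except: {t_try:.3f}s, dict.get: {t_get:.3f}s")
```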

**Memoization**

This problem also lends itself well to memoization, which actually makes the logic a ton simpler since you can just write:

```
@memoize
def my_square_chain(i):
    if i == 1 or i == 89:
        return i
    return my_square_chain(square_digits(i))
```

It’s shorter and easier to reason about, with the added benefit that it’s also a little bit faster (31.8s).
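The `memoize` decorator is not defined in the snippet above. One possible implementation, shown here as an assumption rather than the one actually used for the timings, is a small dict-backed wrapper; in Python 3, `functools.lru_cache(maxsize=None)` would serve the same purpose.

```python
import functools

def memoize(func):
    """Cache results of a one-argument function in a dict (assumed implementation)."""
    cache = {}

    @functools.wraps(func)
    def wrapper(arg):
        if arg not in cache:
            cache[arg] = func(arg)
        return cache[arg]

    return wrapper

def square_digits(num):
    """Sum of the squares of the digits, as in the question's code."""
    total = 0
    while num:
        total += (num % 10) ** 2
        num //= 10
    return total

@memoize
def my_square_chain(i):
    if i == 1 or i == 89:
        return i
    # The recursive call goes through the wrapper, so every
    # intermediate value in the chain is cached as well.
    return my_square_chain(square_digits(i))

print(my_square_chain(44), my_square_chain(85))
```

The recursion stays shallow because, after the first step, every value is at most 7 × 81 = 567 for starting numbers below ten million.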

**`square_digits`**

You actually sum the squared digits inside the function, which makes the code longer and less reusable. Instead, just `yield` them:

```
def square_digits(num):
    while num:
        yield (num % 10) ** 2
        num //= 10
```

Example usage:

```
>>> square_digits(146)
<generator object square_digits at 0x7f6a0b5930a0>
>>> list(_) # `_` is the previous value
[36, 16, 1]
```

In fact you may make your function even more versatile:

```
def digits(n):
    while n:
        yield n % 10
        n //= 10
```

Now you may use it in solving any problem that involves finding digits.

Also note that using my functions requires a slight change in the calling code:

- `square_digits(num)` -> `sum(square_digits(num))`
- `square_digits(num)` -> `sum(i**2 for i in digits(num))`
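A short sketch of how the generic `digits` generator might be reused follows; the wrapper names `digit_square_sum` and `digit_sum` are illustrative. Note that digits come out least significant first, and that `digits(0)` yields nothing.

```python
def digits(n):
    """Yield the decimal digits of n, least significant first."""
    while n:
        yield n % 10
        n //= 10

def digit_square_sum(n):
    # Drop-in replacement for the original summing square_digits.
    return sum(d ** 2 for d in digits(n))

def digit_sum(n):
    # The same generator serves an unrelated digit problem.
    return sum(digits(n))

print(list(digits(146)), digit_square_sum(44), digit_sum(146))
```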