Memoization with Fibonacci and Collatz Sequences

What is memoization? Simply put, memoization is the process of designing a program to store intermediate results to make it more efficient. Especially in a situation where we have a recursive function or where costly computations would need to be repeated several times, memoization can greatly speed up our calculations. Here I will outline this process in Python using the Fibonacci and Collatz sequences as examples.

The Fibonacci sequence - A naive implementation

The Fibonacci sequence is easy to describe in words. We start with the initial values 0 and 1, then each successive term is calculated by adding the last two terms. We have:

\begin{align} F_0 &= 0\\ F_1 &= 1\\ F_n &= F_{n-1} + F_{n-2} \end{align}

The first few terms of the sequence are: $$0,\;1,\;1,\;2,\;3,\;5,\;8,\;13,\;21,\;34,\;55,\;89,\;144,\; \ldots$$

We can easily see the recursive nature of this sequence. We have $0+1=1$, $1+1=2$, $1+2=3$, $2+3=5$, $\dots$, with this pattern continued indefinitely. We can easily define this function in Python:

 

def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n-1) + fib(n-2)

 

Let's take a look at the performance of the function by calculating the first fifty terms in the sequence and recording the times for each function call:

 

from time import time

times = []

for n in range(50):
    start = time()
    fib(n)
    stop = time()
    duration = stop-start
    times.append((n, duration))

 

Looking at a graph of these times: What's going on here? We're just adding numbers, which shouldn't be an expensive operation, yet our execution times are huge and appear to grow exponentially. In total this took over an hour to run on my desktop!

The problem is with the way that we have defined our function. By directly utilizing the recursion of the Fibonacci sequence in a function call, we have inadvertently made many more calculations than are needed. For instance, let's think about what is happening when we call fib(5). Python does the following calculations, where each line roughly represents a step in the recursion:

\begin{align} \text{fib}(5) &= \text{fib}(4) + \text{fib}(3)\\ &= \text{fib}(3) + \text{fib}(2) + \text{fib}(2) + \text{fib}(1)\\ &= \text{fib}(2) + \text{fib}(1) + \text{fib}(1) + \text{fib}(0) + \text{fib}(1) + \text{fib}(0) + 1\\ &= \text{fib}(1) + \text{fib}(0) + 1 + 1 + 0 + 1 + 0 + 1\\ &= 1 + 0 + 1 + 1 + 0 + 1 + 0 + 1 \\ &= 5 \end{align}

Look at how many calls we are making to the function! Even for such a low index, we call the function 15 times. It is clear that we are not making efficient use of this sequence's properties. The same values are recalculated several times within a single call of the function. Only fib(0) and fib(1) return an integer directly, and even these are called redundantly. More than that, this process starts all over when we need to calculate fib(6).
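To see the redundancy concretely, here is a minimal sketch that tallies how often each index is evaluated. The name fib_counted and the counter dictionary are illustrative additions, not part of the function defined above.

```python
def fib_counted(n, counter):
    """Naive recursive Fibonacci that also records how often each index is evaluated."""
    counter[n] = counter.get(n, 0) + 1
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib_counted(n - 1, counter) + fib_counted(n - 2, counter)


calls = {}                       # maps index -> number of evaluations
result = fib_counted(5, calls)
total_calls = sum(calls.values())
print(result, total_calls, calls)
```

For index 5 the tally shows 15 calls in total, with fib(1) evaluated five times and fib(0) three times, matching the ones and zeros in the expansion above.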

The Fibonacci sequence - using Memoization

What we need is a clean way to implement a dictionary that records all intermediate calls of the function so that each index of the sequence is calculated exactly once, and any future calls use this dictionary.

Here is where we can implement memoization. Surprisingly, we are going to be able to use the same function, but we will wrap it in a class that has access to a dictionary of previously calculated values. We start by defining the following class:

 

class Memoize:
    def __init__(self, fn):
        self.fn = fn
        self.memo = {}

    def __call__(self, *args):
        if args not in self.memo:
            self.memo[args] = self.fn(*args)
        return self.memo[args]

 

Let's break down what's happening. The method __init__ initializes the class by

- storing the passed function as self.fn
- creating the dictionary self.memo

The method __call__ is executed when the instance is called. It checks whether self.memo already contains the function arguments:

- if it does, we can simply return the value from the dictionary
- if not, the function is evaluated and added to the dictionary for future use
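The two cases above can be observed directly. The sketch below wraps a stand-in function (square and the evaluations list are hypothetical, used only for illustration) and records every time the wrapped function really runs.

```python
class Memoize:
    def __init__(self, fn):
        self.fn = fn
        self.memo = {}

    def __call__(self, *args):
        if args not in self.memo:
            self.memo[args] = self.fn(*args)
        return self.memo[args]


evaluations = []  # every argument the wrapped function was actually run with


def square(x):
    """Hypothetical stand-in for an expensive computation."""
    evaluations.append(x)
    return x * x


cached_square = Memoize(square)
first = cached_square(4)   # not in memo yet: square(4) runs and is stored
second = cached_square(4)  # already in memo: the stored value is returned
print(first, second, evaluations, cached_square.memo)
```

Note that the dictionary keys are tuples of the arguments, because __call__ collects them with *args.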

To initialize this class with the function fib we do the following:

 

@Memoize
def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n-1) + fib(n-2)

 

Here @Memoize is a decorator. This is simply a shorter way of writing fib = Memoize(fib), which replaces our original function with an instance of the new class. Now how will fib(5) perform as compared to above? If fib(4) and fib(3) are in self.memo it would be a single line!

For instance, after evaluating fib(4) we can see that self.memo contains:

 

{(1,): 1, (0,): 0, (2,): 1, (3,): 2, (4,): 3}

 

so that if fib(5) is called, the function will look up fib(4) and fib(3) in self.memo instead of repeating the longer calculation shown earlier.
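A short, self-contained check of this behaviour is sketched below. The variable names memo_after_4 and new_keys are illustrative; the class and function are the ones defined above.

```python
class Memoize:
    def __init__(self, fn):
        self.fn = fn
        self.memo = {}

    def __call__(self, *args):
        if args not in self.memo:
            self.memo[args] = self.fn(*args)
        return self.memo[args]


@Memoize
def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)


fib(4)
memo_after_4 = dict(fib.memo)                  # snapshot of the cache
fib(5)
new_keys = set(fib.memo) - set(memo_after_4)   # entries added by fib(5)
print(memo_after_4)
print(new_keys)
```

Evaluating fib(5) after fib(4) adds exactly one new entry, since both of its summands are already stored.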

The performance benefit is astounding. Using timeit to measure our new function, we can produce the first 100,000 Fibonacci numbers with average time:

 

%%timeit
@Memoize
def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n-1) + fib(n-2)

times = []

for n in range(100000):
    start = time()
    fib(n)
    stop = time()
    duration = stop-start
    times.append((n, duration))

 

Output:

 

291 ms ± 7.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

 

(Please note that I am using Jupyter, where %%timeit is a built-in cell magic command.)
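Outside of Jupyter, a similar measurement can be sketched with the standard library's timeit module. The helper build_and_run is a hypothetical name, and the count of 1000 terms is chosen only to keep the example quick; it is not the measurement reported above.

```python
import timeit


class Memoize:
    def __init__(self, fn):
        self.fn = fn
        self.memo = {}

    def __call__(self, *args):
        if args not in self.memo:
            self.memo[args] = self.fn(*args)
        return self.memo[args]


def build_and_run(count):
    """Create a fresh memoized fib and evaluate the first `count` terms in order."""
    @Memoize
    def fib(n):
        if n == 0:
            return 0
        elif n == 1:
            return 1
        else:
            return fib(n - 1) + fib(n - 2)

    return [fib(n) for n in range(count)]


# Seven runs of one execution each, mirroring "7 runs, 1 loop each".
runs = timeit.repeat(lambda: build_and_run(1000), repeat=7, number=1)
print(f"fastest of {len(runs)} runs: {min(runs) * 1000:.3f} ms")
```

Building the memoized function inside the helper means every run starts with an empty cache, just as re-defining fib inside the %%timeit cell does.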

We can make this faster, but I'll save that discussion for a future post, as it has more to do with the mathematics than with how we use memory.

One last note: these numbers are huge, with the last number $F_{99,999} \approx 1.60528 \times 10^{20898}$ being over 20,000 digits long! I tried to compute up to the one-millionth Fibonacci number but my computer became unresponsive.
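As an aside, Python's standard library already ships a memoizing decorator, functools.lru_cache. The sketch below swaps it in for the hand-written Memoize class; it is an alternative, not the approach used in this post.

```python
from functools import lru_cache


@lru_cache(maxsize=None)  # maxsize=None means the cache is never evicted
def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)


value = fib(30)
info = fib.cache_info()  # named tuple with hits, misses, maxsize, currsize
print(value, info)
```

Each index from 0 to 30 is computed once, so cache_info() reports 31 misses and a cache size of 31.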

Lengths of Collatz Sequences

Now that we have a bit of a feel for things, let's look at a slightly more complex sequence. Consider the following function:

$$f(n) = \begin{cases} \frac{n}{2} &\text{if } n \equiv 0 \pmod{2}\\[4px] 3n+1 & \text{if } n\equiv 1 \pmod{2} .\end{cases}$$

We use this function to create a sequence by repeatedly applying $f$ until we reach a value of 1 (after which the sequence would cycle through 4, 2, 1 forever). For instance, we have:

$$13 \rightarrow 40 \rightarrow 20 \rightarrow 10 \rightarrow 5 \rightarrow 16 \rightarrow 8 \rightarrow 4 \rightarrow 2 \rightarrow 1$$
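This example can be reproduced with a few lines of Python. The names collatz_step and collatz_terms below are illustrative and are separate from the functions defined later in the post.

```python
def collatz_step(n):
    """Apply f once: halve even numbers, map odd n to 3n + 1."""
    return n // 2 if n % 2 == 0 else 3 * n + 1


def collatz_terms(n):
    """Return the full list of terms from n down to 1."""
    terms = [n]
    while n != 1:
        n = collatz_step(n)
        terms.append(n)
    return terms


terms = collatz_terms(13)
print(" -> ".join(str(t) for t in terms))
```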

The Collatz Conjecture is the statement that any initial starting value creates a sequence that will end at 1. Despite the apparent simplicity of its statement, this problem is considered one of the great open problems in mathematics, with even the great mathematician Erdős being quoted as saying: "Mathematics is not yet ripe enough for such questions."

Nevertheless, it is easy to define the sequence in Python. Following Project Euler's problem 14, consider the following question: Which starting number, under one million, produces the longest sequence?

I can define the following two functions:

 

def collatz(n):
    if n == 1:
        return 1
    elif n % 2 == 0:
        return n//2
    else:
        return 3*n + 1

def collatz_sequence(n):
    sequence = [n]

    while 1 not in sequence:
        n = collatz(n)
        sequence.append(n)

    return len(sequence)

 

and brute force the solution:

 

lengths = [collatz_sequence(n) for n in range(1, 1000000)]
# lengths[0] belongs to n = 1, so shift the list index by one
solution = lengths.index(max(lengths)) + 1

 

This takes approximately a minute and a half on my desktop, returning that the longest chain for initial values under one million begins at 837799 and produces a sequence of 525 numbers. This isn't bad, but we can do better (even without thinking about the math more) by using memoization.
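The brute-force search depends on list positions, which start at 0 while the starting values start at 1. A minimal sketch that tracks the starting value directly is shown below; collatz_length and longest_chain are hypothetical names, and the length is counted with a running total instead of building a list.

```python
def collatz_length(n):
    """Number of terms in the sequence from n down to 1, counting both ends."""
    length = 1
    while n != 1:
        n = n // 2 if n % 2 == 0 else 3 * n + 1
        length += 1
    return length


def longest_chain(limit):
    """Return (start, length) of the longest sequence with 1 <= start < limit."""
    best_start, best_length = 1, 1
    for start in range(1, limit):
        length = collatz_length(start)
        if length > best_length:
            best_start, best_length = start, length
    return best_start, best_length


print(longest_chain(14))
```

Calling longest_chain(1000000) performs the same search as above, still without any memoization.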

Lengths of Collatz Sequences Using Memoization

This is a little harder to mentally visualize, so let's take a look at a diagram, which shows all sequences with length less than 20:

One thing we can notice is that there is a lot of repetition in our sequences, as they form large "branches" that split off at various points. For instance, every sequence in the diagram, apart from those starting at 1, 2, 4, or 8, includes the subsequence $$16 \rightarrow 8 \rightarrow 4 \rightarrow 2 \rightarrow 1$$

So imagine for a moment that I have already evaluated collatz_sequence(16) and seen that it has length 5. Then if I were evaluating collatz_sequence(32) or collatz_sequence(5), the starting points that have 16 as their next step, I know that the length will evaluate to $1+5=6$ without having to walk through the remaining steps again. This is a perfect opportunity for memoization.
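The reasoning here amounts to the recurrence length(n) = 1 + length(f(n)). A minimal recursive sketch of that idea follows; it uses a plain dictionary and a hypothetical function name, and it is not the implementation developed in the rest of this post.

```python
def collatz_length(n, cache):
    """Length of the sequence starting at n, reusing lengths stored in cache."""
    if n == 1:
        return 1
    if n not in cache:
        next_value = n // 2 if n % 2 == 0 else 3 * n + 1
        cache[n] = 1 + collatz_length(next_value, cache)
    return cache[n]


cache = {}
length_16 = collatz_length(16, cache)  # walks 16 -> 8 -> 4 -> 2 -> 1 once
length_32 = collatz_length(32, cache)  # one step to 16, then a lookup
length_5 = collatz_length(5, cache)    # one step to 16, then a lookup
print(length_16, length_32, length_5)
```

The recursion depth equals the number of uncached steps, so very long chains may need an iterative version like the one used below.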

I can leave the collatz function as it is, but will need to slightly alter the memoization class and the function that computes the sequence length. I actually can simplify the memoization class to:

 

class Memoize:
    def __init__(self, fn):
        self.fn = fn
        self.memo = {}

    def __call__(self, *args):
        # pass the cache to the wrapped function and hand back its result
        return self.fn(*args, self.memo)

 

Notice that instead of the __call__ method checking self.memo, it just passes the dictionary as an argument to our function. For this problem, I felt it was easier to rewrite the function with the memoization dictionary as an argument. So I will now write the sequence function:

 

@Memoize
def collatz_sequence(n, cache):
    if n == 1:
        cache[n] = 1
        return 1

    copy = n
    sequence = [n]

    while 1 not in sequence:
        n = collatz(n)
        if n in cache:
            cache[copy] = len(sequence) + cache[n]
            break
        else:
            sequence.append(n)

 

Let's break down what's happening. First in lines 2 through 8:

- If the initial value is 1, we add the key-value pair 1:1 to self.memo
- We make a copy of the initial number for later use with self.memo
- We initialize the list that will hold our Collatz sequence

Then in the while loop:

- We continue so long as our sequence has not reached 1
- We calculate the next value of the Collatz sequence
- If the next value is in self.memo, its stored value is the remaining length, so we record the total length for the initial number and stop
- If the next value is not in self.memo we continue by appending to the sequence

Note that this function is meant to be called in numerical order, starting from 1, to take advantage of the memoization. As an example, if we compute the lengths of the first 13 sequences with:

 

for n in range(1, 14):
    collatz_sequence(n)

 

then collatz_sequence.memo will contain:

 

{1: 1, 2: 2, 3: 8, 4: 3, 5: 6, 6: 9, 7: 17, 8: 4, 9: 20, 10: 7, 11: 15, 12: 10, 13: 10}

 

The performance greatly improves. Computing the sequence lengths for every initial value less than one million and again measuring with timeit:

 

%%timeit
@Memoize
def collatz_sequence(n, cache):
    if n == 1:
        cache[n] = 1
        return 1

    copy = n
    sequence = [n]

    while 1 not in sequence:
        n = collatz(n)
        if n in cache:
            cache[copy] = len(sequence) + cache[n]
            break
        else:
            sequence.append(n)

    # after a cache hit the list only holds the uncached prefix,
    # so prefer the stored total when there is one
    return cache.get(copy, len(sequence))

for n in range(1, 1000000):
    collatz_sequence(n)

 

Output:

 

2.53 s ± 22.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

 

A Small Improvement

However, we can easily make a small change that will give just a bit more improvement. Look at the dictionary above. What is it missing? We calculated through the sequence starting with 13, but we didn't add any of these intermediate values to our dictionary! We can add a loop to capture these as well:

 

@Memoize
def collatz_sequence(n, cache):
    if n == 1:
        cache[n] = 1
        return 1

    copy = n
    sequence = [n]

    while 1 not in sequence:
        n = collatz(n)
        if n in cache:
            cache[copy] = len(sequence) + cache[n]
            break
        else:
            sequence.append(n)

    for index, value in enumerate(sequence[::-1]):
        cache[value] = cache[n] + index + 1

 

Now if we run the first 13 sequences again, self.memo will contain:

 

{1: 1, 2: 2, 3: 8, 4: 3, 8: 4, 16: 5, 5: 6, 10: 7, 6: 9, 7: 17, 20: 8, 40: 9, 13: 10, 26: 11, 52: 12, 17: 13, 34: 14, 11: 15, 22: 16, 9: 20, 14: 18, 28: 19, 12: 10}
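To compare the two caches side by side, the sketch below re-implements both variants as one plain function with a flag. The name fill_cache and the store_intermediate parameter are hypothetical; the logic follows the two versions of collatz_sequence above, with the cache seeded with 1: 1.

```python
def collatz(n):
    if n == 1:
        return 1
    elif n % 2 == 0:
        return n // 2
    else:
        return 3 * n + 1


def fill_cache(limit, store_intermediate):
    """Cache sequence lengths for starting values 1 <= start < limit."""
    cache = {1: 1}
    for start in range(2, limit):
        path = [start]              # values visited before a cached one
        n = collatz(start)
        while n not in cache:
            path.append(n)
            n = collatz(n)
        known = cache[n]            # remaining length from the cached value
        if store_intermediate:
            # walk the path backwards so each value gets its own length
            for offset, value in enumerate(reversed(path), start=1):
                cache[value] = known + offset
        else:
            cache[start] = known + len(path)
    return cache


basic = fill_cache(14, store_intermediate=False)
improved = fill_cache(14, store_intermediate=True)
print(len(basic), len(improved))
```

For initial values 1 through 13 the first variant stores 13 lengths and the second stores 23.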

 

We can visualize the change that we have made to the function by looking at the graph below, which is just a piece of the larger graph above. I have included all sequences that are generated from initial values 1 through 13. I have colored the initial values in blue and all other included numbers in red. Our first version failed to account for these additional numbers that appear in the sequences but are outside our range of initial values. In this particular case we store 23 values instead of 13, nearly twice as many, for future use. Finally, we can run this new version for initial values of less than one million and use timeit again:

 

%%timeit
@Memoize
def collatz_sequence(n, cache):
    if n == 1:
        cache[n] = 1
        return 1

    copy = n
    sequence = [n]

    while 1 not in sequence:
        n = collatz(n)
        if n in cache:
            cache[copy] = len(sequence) + cache[n]
            break
        else:
            sequence.append(n)

    for index, value in enumerate(sequence[::-1]):
        cache[value] = cache[n] + index + 1

for n in range(1, 1000000):
    collatz_sequence(n)

 

Output:

 

1.99 s ± 43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

 

Conclusion

As we have seen, the use of memoization can have a drastic effect on the performance of a program. At their heart, all of these optimizations revolve around understanding and exploiting the structure of a particular problem, using an appropriate data structure to avoid repeated calculations wherever possible.

As we will see in a later post, this same idea of exploiting structure has vast generalizations that allow us to work with general mathematical frameworks and translate these into data structures that take advantage of the common structure between seemingly disparate problems.