Solving recursions with memoization

1. Introduction

The main idea of recursion is to do just a tiny bit of work in a function, and then pass on a slightly simpler or smaller problem to another call of the same function. (It’s a bit like a manager: do as little as possible, but just enough that it is clear that you contributed, and then push the problem to another manager.) As such, recursion is an elegant and general method. However, if used carelessly, the amount of computation time can grow very quickly. To prevent this, we can use a truly great, but simple, idea: instead of recomputing intermediate results, store these results and look them up when we need them at a later moment. In this chapter we show how to use this technique, which is called memoization. (You should know that memoization lies at the heart of how spreadsheet programs, such as Excel, work.)

2. Why store results?

2.1. Two toy examples

We can use recursion to compute the factorial \(n!\) of the integer \(n\), like so.

def f(n):
    if n == 1:
        return 1
    return n * f(n - 1)

print(f(4))
24

When using recursion, always think about the stopping condition. If you miss a corner case, the recursion may never stop.

Exercise 1:

What would happen if you used this code to compute f(1/2)?

Solution

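The argument never reaches the stopping condition n == 1: the calls go f(1/2), f(-1/2), f(-3/2), and so on. The program then claims more and more memory for the pending calls. Python stops it with a RecursionError once the maximum recursion depth is exceeded (by default about 1000 in CPython), but in languages without such a guard your program may crash. Thus, stopping conditions should be complete, and in professional code you should always check for weird input.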
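To make the advice about weird input concrete, here is a minimal sketch of a factorial that validates its argument before recursing. The name factorial and the choice of exceptions (TypeError for non-integers, ValueError for negative numbers) are our own assumptions, not a convention of this chapter; the sketch also accepts n = 0, using 0! = 1.

```python
def factorial(n):
    """Recursive factorial that rejects input for which the recursion would never stop."""
    # bool is a subclass of int in Python, so reject it explicitly
    if not isinstance(n, int) or isinstance(n, bool):
        raise TypeError(f"factorial expects an integer, got {n!r}")
    if n < 0:
        raise ValueError(f"factorial expects n >= 0, got {n}")
    if n <= 1:  # stopping condition: 0! = 1! = 1
        return 1
    return n * factorial(n - 1)


print(factorial(4))  # 24
try:
    factorial(1 / 2)
except TypeError as err:
    print("rejected:", err)
```

The checks run before any recursive call, so bad input fails immediately with a clear message instead of deep inside the recursion.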

Note further that the function f first checks the boundary condition, i.e., n == 1, and only when the input falls through all such conditions does it make the recursive call.

Now let us count how often this function is called to compute, for instance, \(5!\).

def f(n):
    global no_times_called
    no_times_called += 1
    if n == 1:
        return 1
    return n * f(n - 1)


no_times_called = 0
f(5)
print(no_times_called)
5

Obviously, to compute \(5!\), the function f is called \(5\) times. In this example the count is trivial, but this does not always hold, as the next example shows.

The number of rabbit pairs supposedly increases per generation according to Fibonacci’s rule:

\begin{equation} F(n) = F(n-1) + F(n-2),\quad F(0) = F(1) = 1. \end{equation}

Let’s implement this as a recursion, and see how often the function is called.

def Fibb(n):
    global no_times_called
    no_times_called += 1
    if n == 0 or n == 1:
        return 1
    return Fibb(n - 1) + Fibb(n - 2)


for n in range(20):
    no_times_called = 0
    Fibb(n)
    print(f"{n=:5d}, {no_times_called=:8d}")
n=    0, no_times_called=       1
n=    1, no_times_called=       1
n=    2, no_times_called=       3
n=    3, no_times_called=       5
n=    4, no_times_called=       9
n=    5, no_times_called=      15
n=    6, no_times_called=      25
n=    7, no_times_called=      41
n=    8, no_times_called=      67
n=    9, no_times_called=     109
n=   10, no_times_called=     177
n=   11, no_times_called=     287
n=   12, no_times_called=     465
n=   13, no_times_called=     753
n=   14, no_times_called=    1219
n=   15, no_times_called=    1973
n=   16, no_times_called=    3193
n=   17, no_times_called=    5167
n=   18, no_times_called=    8361
n=   19, no_times_called=   13529

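Compared to the recursive algorithm to compute \(n!\), this algorithm clearly explodes: the number of calls grows exponentially in \(n\), each count being roughly \(1.6\) times the previous one.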

Exercise 2:

Why is that?

Solution

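The problem is that we don’t store intermediate results, but recompute them every time. For instance, Fibb(n) calls Fibb(n - 1) and Fibb(n - 2), but Fibb(n - 1) calls Fibb(n - 2) once more, so the same values are computed over and over again.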
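A small experiment makes this visible. The sketch below uses our own instrumentation (the Counter named calls is not part of the chapter’s code): it records how often Fibb is evaluated for each argument while computing Fibb(10).

```python
from collections import Counter

calls = Counter()  # calls[k] = number of times Fibb(k) is evaluated


def Fibb(n):
    calls[n] += 1
    if n == 0 or n == 1:
        return 1
    return Fibb(n - 1) + Fibb(n - 2)


Fibb(10)
for k in sorted(calls, reverse=True):
    print(f"Fibb({k}) evaluated {calls[k]} times")
```

The counts per argument are themselves Fibonacci numbers (for instance, Fibb(1) is evaluated 55 times), and they add up to 177, the number of calls reported for n = 10 in the table above.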

2.2. Memoization to the rescue

What if we stored intermediate results rather than recomputing them? Storing intermediate results in this way is known as memoization, a form of caching. Python’s functools module provides the decorator cache (available since Python 3.9) to do exactly this. Let’s see how to use it, and what it does.

from functools import cache


@cache
def Fibb(n):
    global no_times_called
    no_times_called += 1
    if n == 0 or n == 1:
        return 1
    return Fibb(n - 1) + Fibb(n - 2)


no_times_called = 0
Fibb(19)
print(f"{no_times_called=:8d}")
no_times_called=      20

Rather than 13529 function calls, we now need just 20! Amazing, isn’t it? If you don’t believe how useful memoization is, try your hand at F(300) with and without memoization. (Without memoization, you’ll be dead before knowing the answer.)

Some remarks are in order.

  • There is a closed form solution to compute the \(n\)th Fibonacci number. Of course we could have used that, but here we wanted to see how memoization serves to dramatically reduce numerical work.
  • Here we don’t explain how memoization is implemented. It’s not hard, though, and there are excellent explanations on the web.
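For the curious, here is one minimal way memoization can be implemented: a decorator that keeps a dictionary from arguments to results. The name memoize is our own; functools.cache does the same job more efficiently and also handles keyword arguments, so this is only a sketch of the idea, not how the library does it internally. It assumes all arguments are hashable, as dictionary keys must be.

```python
from functools import wraps


def memoize(func):
    """Minimal memoization: remember results in a dict keyed by the positional arguments."""
    table = {}

    @wraps(func)
    def wrapper(*args):
        if args not in table:  # compute only on the first request
            table[args] = func(*args)
        return table[args]  # afterwards, just look the result up

    return wrapper


@memoize
def Fibb(n):
    if n == 0 or n == 1:
        return 1
    return Fibb(n - 1) + Fibb(n - 2)


print(Fibb(300))
```

Because the recursive calls inside Fibb go through the decorated name, every intermediate result lands in the table, which is what brings the work down from exponential to linear in n.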

3. Solving a betting game

Peter and Paul bet a dollar on the outcome of a coin throw. Peter wins the dollar if the coin lands heads; otherwise Paul wins the dollar. Peter starts with \(i\) dollars, Paul with \(n-i\) dollars. They play this game for at most \(t\) rounds. Peter wins the game when he owns all \(n\) dollars within the \(t\) playing rounds. Paul wins the game if he owns all \(n\) dollars, or if all \(t\) rounds have been played without Peter owning all \(n\) dollars. What is the probability \(u(t,i)\) that Peter wins, if the coin lands heads with probability \(p\)?

We can use first-step analysis to decompose \(u(t,i)\). After one throw, \(t-1\) rounds remain. Therefore, writing \(q = 1-p\) for the probability of tails, \(u(t, i)\) satisfies, for \(t\geq 1\) and \(0 < i < n\), the recursion:

\begin{align*} u(t, i) &= p u(t-1, i+1) + q u(t-1, i-1), \\ u(t, 0) &= 0, \quad t\geq 0, \\ u(0, i) &= 0, \quad 0\leq i < n, \\ u(t, n) &= 1, \quad t\geq 0. \end{align*}
Exercise 3:

Explain these equations. (Take particular notice of the boundary conditions: are they complete, and do we always reach them?)

Solution

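At the start, \(t\) rounds remain. After one throw, \(t-1\) rounds remain, and if Peter won that round, which happens with probability \(p\), he owns \(i+1\) dollars; otherwise, with probability \(q\), he owns \(i-1\) dollars. For the first boundary condition, if at any round it happens that \(i=0\), then Peter ran out of dollars, so Paul won all dollars and thereby the game. Thus, in that case, the probability that Peter wins is zero. For the second boundary condition, if there are no more rounds, that is, \(t=0\), and Peter does not own all dollars, that is, \(i<n\), then Peter has not won the game, so again his winning probability is zero. For the third boundary condition, if Peter owns all \(n\) dollars, he has won, however many rounds remain. The conditions are complete and always reached: every step lowers \(t\) by one, so the recursion stops at \(t=0\) at the latest.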

from functools import cache

p = 26/50
q = 1 - p
n = 10

@cache
def u(t, i):
    if i == n:
        return 1
    if i == 0 or t == 0:
        return 0
    return p * u(t - 1, i + 1) + q * u(t - 1, i - 1)

print(u(t=100, i=5))
0.5939863359639174

You see how clean the code stays? Moreover, the code is conceptually identical to the mathematical specification. We don’t need any extra documentation besides the maths!
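Because the code mirrors the maths, a Monte Carlo simulation offers an independent check. The sketch below plays the game many times with Python’s random module; the helper peter_wins, the seed, and the number of runs are arbitrary choices of ours. The estimate should agree with u(100, 5) up to sampling noise of a few thousandths.

```python
import random
from functools import cache

p, n = 26 / 50, 10
q = 1 - p


@cache
def u(t, i):
    """Winning probability of Peter, recursion from the text."""
    if i == n:
        return 1
    if i == 0 or t == 0:
        return 0
    return p * u(t - 1, i + 1) + q * u(t - 1, i - 1)


def peter_wins(t, i, rng):
    """Play at most t rounds from 0 < i < n dollars; True if Peter collects all n dollars."""
    for _ in range(t):
        i += 1 if rng.random() < p else -1  # heads: Peter gains a dollar
        if i == n:
            return True
        if i == 0:
            return False
    return False  # out of rounds without owning all n dollars


rng = random.Random(1)
runs = 20_000
estimate = sum(peter_wins(100, 5, rng) for _ in range(runs)) / runs
print(u(100, 5), estimate)
```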

Before computing the expected duration of the game, let us suppose for the moment that there is no limit on the number of rounds, so that the dependence on \(t\) in \(u(t, i)\) disappears. Then we get the recursion:

\begin{align*} u(i) &= p\, u(i+1) + q\, u(i-1), & u(n) &= 1, & u(0) &= 0. \end{align*}

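Instead of a problem with a boundary at one end only (at \(t=0\)), we now have a two-point boundary value problem, with conditions at \(i=0\) and at \(i=n\). Plain memoized recursion no longer works here: \(u(i)\) needs \(u(i+1)\), which in turn needs \(u(i)\), so the recursion would never end. For this specific example it’s simple to find an analytic solution, but in general two-point boundary value problems are numerically much harder to solve than problems with only initial conditions.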
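For this example the analytic solution is the classical gambler’s ruin formula: with \(r = q/p\), \(u(i) = (1-r^i)/(1-r^n)\) if \(p \neq q\), and \(u(i) = i/n\) if \(p = q\). Substituting it into the recursion shows that it satisfies both the equation and the boundary conditions. The sketch below, with the hypothetical helper u_inf, compares it with the memoized \(u(t, i)\) for a growing number of rounds; as more rounds are allowed, \(u(t, 5)\) should approach \(u(5)\) from below. The largest \(t\) is kept at 400 so that the recursion depth stays below Python’s default limit.

```python
from functools import cache

p, n = 26 / 50, 10
q = 1 - p


@cache
def u(t, i):
    """Finite-horizon winning probability, recursion from the text."""
    if i == n:
        return 1
    if i == 0 or t == 0:
        return 0
    return p * u(t - 1, i + 1) + q * u(t - 1, i - 1)


def u_inf(i):
    """Gambler's ruin: solves u(i) = p u(i+1) + q u(i-1) with u(0) = 0, u(n) = 1."""
    if p == q:
        return i / n
    r = q / p
    return (1 - r**i) / (1 - r**n)


for t in (10, 50, 100, 400):
    print(t, u(t, 5), u_inf(5))
```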

What is the expected duration of the game? Clearly, the game ends when the boundary is hit at \(0\) or \(n\), or when there are no further rounds, i.e., \(t=0\). The recursions for this question are also simple. Let \(v(t, i)\) be the expected duration of the game when at most \(t\) rounds are left and Peter owns \(i\) dollars. Then,

\begin{align*} v(t, i) &= 1 + p v(t-1, i+1) + q v(t-1, i-1), \\ v(t, n) &= v(t,0) = 0, \quad t\geq 0, \\ v(0, i) &= 0, \quad \text{for all } i. \end{align*}
from functools import cache

p = 26/50
q = 1 - p
n = 10


@cache
def v(t, i):
    if i == n or i == 0 or t==0:
        return 0
    return 1 + p * v(t - 1, i + 1) + q * v(t - 1, i - 1)


print(v(100, 5))
24.528910715863503
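As a sanity check, the expected duration without a round limit also has a classical closed form. For \(p \neq q\), \(v(i) = \big(i - n\,u(i)\big)/(q-p)\), where \(u(i)\) is Peter’s winning probability from the gambler’s ruin formula; for \(p = q\), \(v(i) = i(n-i)\). The sketch below, with the hypothetical helper v_inf, checks that the memoized \(v(t, i)\) approaches this value as \(t\) grows.

```python
from functools import cache

p, n = 26 / 50, 10
q = 1 - p


@cache
def v(t, i):
    """Expected number of rounds played, recursion from the text."""
    if i == n or i == 0 or t == 0:
        return 0
    return 1 + p * v(t - 1, i + 1) + q * v(t - 1, i - 1)


def v_inf(i):
    """Solves v(i) = 1 + p v(i+1) + q v(i-1) with v(0) = v(n) = 0 (no round limit)."""
    if p == q:
        return i * (n - i)
    r = q / p
    win = (1 - r**i) / (1 - r**n)  # Peter's winning probability (gambler's ruin)
    return (i - n * win) / (q - p)


print(v(100, 5), v(400, 5), v_inf(5))
```

With at most 100 rounds the game is slightly shorter on average than without a limit, because games that would have lasted longer are cut off.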

4. Some further interesting problems

4.1. Content of a Discrete Hyper Pyramid

We would like to compute the number \(P(n, N)\) of vectors \(x = (x_1, \ldots, x_n)\) such that \(x_i \in \{0,1,\ldots, N\}\) and \(\sum_i x_i \leq N\). Conditioning on the value \(i\) of the last coordinate \(x_n\), the remaining \(n-1\) coordinates must sum to at most \(N-i\), so \(P(n, N)\) satisfies the recursion \[P(n, N) = \sum_{i=0}^N P(n-1, N-i),\] with boundary conditions \(P(1, N) = N+1\) for all \(N\). Note that, by using the summation above, this condition can be replaced by \(P(0, N) = 1\) for all \(N\).

Computing \(P(n,N)\) is easy when we use memoization. In fact, we can code the computation in nearly one line!

from functools import cache

@cache
def P(n, N):
    return 1 if n == 0 else sum(P(n - 1, N - i) for i in range(N + 1))

print(P(n=5, N=80))
32801517
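The count also has a closed form that can serve as a check. Introducing a slack variable \(x_{n+1} = N - \sum_i x_i \geq 0\) turns the problem into counting non-negative integer solutions of \(x_1 + \cdots + x_{n+1} = N\), which by the stars-and-bars argument gives \(P(n, N) = {N+n \choose n}\). The sketch below compares the memoized recursion with math.comb.

```python
from functools import cache
from math import comb


@cache
def P(n, N):
    # recursion from the text, with boundary condition P(0, N) = 1
    return 1 if n == 0 else sum(P(n - 1, N - i) for i in range(N + 1))


for n, N in [(1, 7), (2, 3), (3, 10), (5, 80)]:
    print(n, N, P(n, N), comb(N + n, n))
```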

4.2. A probability problem

We throw a coin that lands heads with probability \(p\), and tails with probability \(q = 1-p\), \(n\) times. What is the probability \(P(n,k)\) of seeing at least \(k\) heads in a row?

A bit of thinking shows that, for \(n \geq k\), \(P(n,k)\) must satisfy the recursion \[P(n,k) = p^k + \sum_{i=1}^k p^{i-1} q\, P(n-i,k),\] because either the first \(k\) throws are all heads, or the first tails occurs at throw \(i\), \(1 \leq i \leq k\), after \(i-1 < k\) heads, after which you have to start all over again with \(n-i\) throws left. For \(n < k\), clearly \(P(n,k) = 0\).

Reasoning similarly, the expected number \(E(n,k)\) of runs of \(k\) heads in a row in \(n\) throws, where we start counting afresh after each completed run, must satisfy the recursion \[E(n,k) = p^k\big(1+E(n-k,k)\big) + \sum_{i=1}^k p^{i-1} q\, E(n-i,k).\]

from functools import cache

p = 0.5
q = 1 - p


@cache
def P(n, k):
    """
    probability to see at least k heads in a row when a coin is thrown n times
    """
    if n < k:
        return 0
    return sum(P(n - i, k) * p**(i - 1) * q for i in range(1, k + 1)) + p**k


@cache
def E(n, k):
    """
    expected number of runs of k heads in a row when a coin is thrown n times,
    counting afresh after each completed run
    """
    if n < k:
        return 0
    tot = sum(E(n - i, k) * p**(i - 1) * q for i in range(1, k + 1))
    tot += p**k * (1 + E(n - k, k))
    return tot


k = 2

for n in range(k, 10):
    print(n, P(n, k), E(n, k))
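Recursions like these are easy to get subtly wrong, so it is worth checking them against brute force for small \(n\). The sketch below enumerates all \(2^n\) outcomes with itertools.product, weights each by its probability, and counts runs of \(k\) heads, starting afresh after each completed run, which is how we read the recursion for \(E(n,k)\). Here \(p\) is passed as an argument instead of being a global; that, and the helper name brute_force, are our own choices.

```python
from functools import cache
from itertools import product


@cache
def P(n, k, p):
    """Recursion from the text: probability of at least k heads in a row in n throws."""
    q = 1 - p
    if n < k:
        return 0.0
    return p**k + sum(p**(i - 1) * q * P(n - i, k, p) for i in range(1, k + 1))


@cache
def E(n, k, p):
    """Recursion from the text: expected number of runs of k heads, restarting after each run."""
    q = 1 - p
    if n < k:
        return 0.0
    tot = sum(p**(i - 1) * q * E(n - i, k, p) for i in range(1, k + 1))
    return tot + p**k * (1 + E(n - k, k, p))


def brute_force(n, k, p):
    """Enumerate all 2**n outcomes; return (P(at least one run), E[number of runs])."""
    prob, expected = 0.0, 0.0
    for throws in product("HT", repeat=n):
        weight = p ** throws.count("H") * (1 - p) ** throws.count("T")
        runs, streak = 0, 0
        for c in throws:
            streak = streak + 1 if c == "H" else 0
            if streak == k:  # a run of k heads is complete: count it and start afresh
                runs += 1
                streak = 0
        if runs > 0:
            prob += weight
        expected += runs * weight
    return prob, expected


for n in range(2, 11):
    bf_p, bf_e = brute_force(n, 2, 0.5)
    print(n, abs(bf_p - P(n, 2, 0.5)) < 1e-9, abs(bf_e - E(n, 2, 0.5)) < 1e-9)
```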

4.3. An interesting exercise

The next exercise challenges you to generalize what we discussed above.

Exercise 4:

From an urn with balls numbered \(1\) to \(N=45\), we repeatedly draw \(6\) different balls at a time (not just one), and put them back after each draw. Find a recursion to compute the expected number \(\mathbb{E}[T]\) of draws necessary to see all \(N=45\) balls, and use memoization to compute the result.

Solution

Write \(T_{n}\) for the expected number of draws needed to finish, given that we have seen \(n\) different balls. Take \(6\) balls. If we knew the number \(k\) of new balls among them, then \(T_{n} = 1 + T_{n+k}\). What is the probability of drawing \(k\) new balls among the \(6\) we pick? Since \(6-k\) balls must come from the \(n\) seen ones and \(k\) from the \(N-n\) unseen ones, this must be

\begin{equation} {n \choose 6-k}{N-n \choose k}\big/{N \choose 6}. \end{equation}

Therefore, when \(N-n\geq 6\),

\begin{align*} T_{n} = 1 + \sum_{k=0}^{6} \frac{{n \choose 6-k}{N-n \choose k}}{{N \choose 6}} T_{n+k}. \end{align*}

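When fewer than \(6\) balls remain unseen, say just \(2\), we cannot pick \(k=6\) new balls; in general, we can pick at most \(\min\{6, N-n\}\) new balls. (The binomial coefficient \({N-n \choose k}\) is zero for \(k > N-n\), so the extra terms vanish anyway, but it is cleaner to leave them out.) Hence, moving the \(k=0\) term, which contains \(T_{n}\) itself, to the left-hand side,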

\begin{align*} T_{n} &= 1 + \sum_{k=0}^{\min\{6, N-n\}} \frac{{n \choose 6-k}{N-n \choose k}}{{N \choose 6}} T_{n+k} \\ &\implies \\ T_{n} - \frac{{n \choose 6}{N-n \choose 0}}{{N \choose 6}} T_{n} &=1 + \sum_{k=1}^{\min\{6, N-n\}} \frac{{n \choose 6-k}{N-n \choose k}}{{N \choose 6}} T_{n+k} \\ &\implies \\ T_{n}\left( {N\choose 6} -{n \choose 6} \right) &={N\choose 6} + \sum_{k=1}^{\min\{6, N-n\}}{n \choose 6-k}{N-n \choose k} T_{n+k} . \end{align*}
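from math import comb
from functools import cache

N = 45  # number of balls in the urn


@cache
def T(n, m):
    """Expected number of draws of m balls needed to see all N balls, given n balls seen."""
    if n >= N:
        return 0
    res = comb(N, m)
    # k = number of new balls in the next draw; the k = 0 term was moved to the left
    for k in range(1, min(m, N - n) + 1):
        weight = comb(n, m - k) * comb(N - n, k)
        res += weight * T(n + k, m)
    return res / (comb(N, m) - comb(n, m))


print(T(0, 6))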
31.497085595869386
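A quick Monte Carlo simulation can confirm the recursion. The sketch below draws 6 different balls with random.Random.sample, puts them back, and repeats until all 45 balls have been seen; the seed and the number of runs are arbitrary. It restates the recursion with the draw size fixed at M = 6 (our own simplification of T(n, m) above), and the simulated mean should be close to it.

```python
import random
from functools import cache
from math import comb

N, M = 45, 6  # balls in the urn, balls per draw


@cache
def T(n):
    """Expected number of further draws, given n different balls seen (recursion from the text)."""
    if n >= N:
        return 0
    res = comb(N, M)
    for k in range(1, min(M, N - n) + 1):
        res += comb(n, M - k) * comb(N - n, k) * T(n + k)
    return res / (comb(N, M) - comb(n, M))


def draws_until_all_seen(rng):
    """Simulate one experiment: draw M different balls, put them back, until all N are seen."""
    seen = set()
    draws = 0
    while len(seen) < N:
        seen.update(rng.sample(range(1, N + 1), M))
        draws += 1
    return draws


rng = random.Random(2024)
runs = 5000
estimate = sum(draws_until_all_seen(rng) for _ in range(runs)) / runs
print(f"recursion: {T(0):.3f}, simulation: {estimate:.3f}")
```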

5. Summary

With memoization we can speed up recursive computations tremendously. In fact, when the recursions are less trivial than the one for \(n!\), memoization often determines whether the recursive computation finishes in reasonable time at all. As the code examples above show, using this technique is very simple: just put the @cache decorator above a Python function. There is a word of warning too. In complex situations we need to think about the depth of the recursion; Python, for instance, raises a RecursionError when the recursion becomes too deep. When this becomes a problem, we can chop up the computations into smaller steps. By the time you get to this level of problems, you’ll know how to handle this too.
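One way to read the advice about chopping up computations is to fill the cache bottom-up in chunks, so that no single call recurses deeply. Calling Fibb(5000) directly on an empty cache would need a recursion depth of about 5000, beyond CPython’s default limit of 1000. The hypothetical helper fibb_in_steps below first evaluates Fibb at every 200th argument, so each call recurses only about 200 levels before it hits a cached value. The step size is an arbitrary choice.

```python
from functools import cache


@cache
def Fibb(n):
    if n == 0 or n == 1:
        return 1
    return Fibb(n - 1) + Fibb(n - 2)


def fibb_in_steps(n, step=200):
    """Warm the cache from small to large arguments, then compute Fibb(n)."""
    for m in range(0, n, step):
        Fibb(m)  # each call recurses at most about `step` levels below cached values
    return Fibb(n)


print(fibb_in_steps(5000) % 10**9)  # last nine digits of a number with over 1000 digits
```

Raising the limit with sys.setrecursionlimit is another option, but it only moves the problem, whereas warming the cache keeps every call shallow.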