Help support the author by donating or purchasing a copy of the book (not available yet)

Chapter 16 - Recursion & Quicksort

16.1 - Recursion

Before I go any further I want to say this: writing a recursive solution is DIFFICULT when starting out. Don’t expect it to work the first time. Writing a recursive solution involves a whole new way of thinking and can be quite off-putting to newcomers, but stick with it and you’ll wonder how you ever found it so difficult.

So what exactly is recursion? A recursive solution is one where the solution to a problem is expressed as an operation on a simplified version of the same problem.

Now that probably didn’t make a whole lot of sense, and it’s tricky to explain exactly how it works, so let’s look at a simple example:

def reduce(x):
    return reduce(x-1)

So what’s happening here? We have a function called reduce which takes a parameter x (an integer), and we want it to reduce x to 0. Sounds simple, so let’s run it with 5 as the parameter and see what happens:

>>> reduce(5)
RuntimeError: maximum recursion depth exceeded

An error? Well, let’s take a closer look at what’s going on here:

reduce(5) calls reduce(4), which calls reduce(3), and so on. Nothing ever tells the function to stop, so our initial call reduce(5) is the first in an endless chain of calls to reduce(). Each time reduce() is called, Python creates a data structure to represent that particular call to the function. This data structure is called a stack frame, and a stack frame occupies memory. An endless chain of calls would need an infinite number of stack frames and therefore an infinite amount of memory, which we don’t have. Python protects itself by limiting how deep the chain of calls may go, and once that limit is reached it stops our program with the error above.
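If you want to see that limit for yourself, here is a small sketch. It assumes Python 3.5 or later, where the error is reported as RecursionError (a subclass of RuntimeError); the variable names limit and hit_limit are just for illustration.

```python
import sys


def reduce(x):
    # No base case: every call makes another call.
    return reduce(x - 1)


# The maximum depth of nested calls Python will allow (1000 by default in CPython).
limit = sys.getrecursionlimit()

try:
    reduce(5)
    hit_limit = False
except RecursionError:
    # Raised once the chain of pending calls reaches the limit.
    hit_limit = True

print(limit, hit_limit)
```

The program is stopped long before memory actually runs out, which is why we see a tidy error rather than a frozen machine.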

To fix this problem we need to add in something referred to as the base case. This is simply a check we add so that our function knows when to stop calling itself. The base case is normally the simplest version of the original problem and one we always know the answer to.

Let’s fix our error by adding a base case. For this function we want it to stop when we reach 0.

def reduce(x):
    if x == 0:
        return 0
    return reduce(x-1)

So now when we run our program we get:

>>> reduce(5)
0

Great, it works! Let’s take another look at what is happening this time. We call reduce(5), which calls reduce(4), and so on until reduce(0) is called. Stop here: we have hit our base case. The function now returns 0 to the previous call, reduce(1), which returns 0 to reduce(2), and so on up to reduce(5), which returns our answer. There you go, your first recursive solution!
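One way to watch the calls go down and the answers come back up is to print a message on the way in and on the way out. This is only a sketch: reduce_traced and its depth parameter are names I've made up for the illustration, and depth exists purely to indent the output.

```python
def reduce_traced(x, depth=0):
    indent = "  " * depth  # deeper calls are indented further
    print(f"{indent}reduce({x}) called")
    if x == 0:
        result = 0  # base case: nothing more to call
    else:
        result = reduce_traced(x - 1, depth + 1)
    print(f"{indent}reduce({x}) returns {result}")
    return result


answer = reduce_traced(3)
```

The "called" lines appear from the outermost call inwards, and the "returns" lines appear in the opposite order, which is exactly the unwinding described above.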

Let’s look at a more practical example. Let’s write a recursive solution to find N! (N factorial):

N factorial is defined as: N! = N * (N-1) * (N-2) * ... * 2 * 1.

For example, 4! = 4 * 3 * 2 * 1

Our base case for this example is N = 0, as 0! is defined as 1. So let’s write our function:

def factorial(n):
    if n == 0:
        return 1
    return n * factorial(n-1)

OK, let’s see that in action. We’ll call our function and pass 4 in for n:

factorial(4) calls factorial(3), which calls factorial(2), and so on down to factorial(0). That is our base case, so it returns 1 to the previous call, factorial(1). factorial(1) multiplies 1 by 1 and passes the result, 1, back to factorial(2). factorial(2) multiplies 2 by 1 and passes 2 back to factorial(3). factorial(3) multiplies 3 by 2 and passes 6 back to factorial(4). Finally, factorial(4) multiplies 4 by 6, which evaluates to 24, and that is returned as our answer.
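If you'd like to check the walk-through, one option is to compare the recursive function against a plain loop and against Python's own math.factorial. The name factorial_iterative is my own, added only for the comparison.

```python
import math


def factorial(n):
    if n == 0:
        return 1  # base case: 0! is defined as 1
    return n * factorial(n - 1)


def factorial_iterative(n):
    # Same result built up with a loop instead of recursive calls.
    result = 1
    for k in range(2, n + 1):
        result *= k
    return result


for n in range(6):
    print(n, factorial(n), factorial_iterative(n), math.factorial(n))
```

All three columns should agree on every row.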

These are very simple cases and not super useful but being able to spot when a recursive solution may be implemented and then implementing that solution is a skill that will make you a better programmer. For certain problems recursion may offer an intuitive, simple, and elegant solution.

For example, you may implement a recursive solution to find the number of nodes in a tree structure, count how many leaf nodes are in a tree, or even return whether a binary search tree is AVL balanced or not. These solutions tend to be quite short and elegant, and they can replace the long iterative approach you may have been taking.

As I’ve said, recursion isn’t easy, but stick with it. It will click, and it will seem so simple you’ll wonder how you ever found it so difficult.
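To give a feel for the tree examples, here is a rough sketch of counting nodes and leaves in a binary tree. The Node class is a made-up, minimal stand-in for whatever tree structure you are actually working with.

```python
class Node:
    """Hypothetical binary tree node, used only for this sketch."""

    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def count_nodes(node):
    if node is None:
        return 0  # base case: an empty tree has no nodes
    return 1 + count_nodes(node.left) + count_nodes(node.right)


def count_leaves(node):
    if node is None:
        return 0  # base case: an empty tree has no leaves
    if node.left is None and node.right is None:
        return 1  # a node with no children is a leaf
    return count_leaves(node.left) + count_leaves(node.right)


tree = Node(8, Node(3, Node(1), Node(6)), Node(10, None, Node(14)))
print(count_nodes(tree))   # 6
print(count_leaves(tree))  # 3
```

Notice that each function is only a few lines long: the base case handles the empty tree, and the recursive case combines the answers for the two subtrees.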

16.2 - Memoization

We've written the solution to the Fibonacci problem a couple of times throughout this book, each time using an iterative approach. We can, however, take a recursive approach. It may be difficult to recognize when a recursive solution is an option, but this is a great example of a problem where it is!

Let's look at a solution to the following problem:

Write a program that takes an integer n and computes the n-th Fibonacci number:

def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)

That's much simpler! In this example we'll take fib(0) to be 1 and fib(1) to be 1 also. Therefore our base case is: if n is 0 or 1 we return 1. The recursive case comes straight from the definition of the Fibonacci sequence, that is:
fib(n)=fib(n − 1)+fib(n − 2)
However, you may have noticed that if you input anything greater than roughly 35 to 40, your program really begins to grind to a halt, and for anything above 45 to 50, well... go make a cup of coffee and come back and it MIGHT be finished. What's causing this? Let's look at a diagram of the first few recursive calls to fib().

We can see that when we call fib(5) we, at various points in our function's execution, call fib(1) 5 times and fib(2) 3 times! This is unnecessary. Why should we have to recompute the value for fib(1) over and over again?
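You can confirm those numbers by counting the calls yourself. In this sketch, fib_counted is the same function as above with one extra line that records each call; the calls counter is an addition of mine for the illustration.

```python
from collections import Counter

calls = Counter()  # calls[n] is how many times fib_counted(n) was called


def fib_counted(n):
    calls[n] += 1
    if n <= 1:
        return 1
    return fib_counted(n - 1) + fib_counted(n - 2)


result = fib_counted(5)
print(result)                # 8
print(calls[1], calls[2])    # 5 3
print(sum(calls.values()))   # 15 calls in total
```

Fifteen calls to compute one small number is already wasteful, and the count grows very quickly as n gets bigger.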

We can reduce the number of times we need to calculate various values by using a technique called memoization. Memoization is an optimization technique used to speed up our programs by caching the results of expensive function calls. A cache is essentially a temporary storage area and we can implement a cache in our fib() function to store previously calculated values and just pull them out of the cache when needed rather than making unnecessary function calls.

We can use a dictionary to implement a cache. This gives us O(1) lookups on average and will drastically improve the performance of our function.

Let's look at how that's done:

def fib(n, cache=None):
    # Avoid default mutable argument trap!
    if cache is None:
        cache = {}
    # Base case
    if n <= 1:
        return 1
    # If value not in cache then calculate it and store in cache
    elif n not in cache:
        cache[n] = fib(n-1, cache) + fib(n-2, cache)
    # Return the value
    return cache[n]

We can now call fib(900) and get the answer straight away! We can't go much higher, though, because the calls still nest about n deep and, by default, Python limits the recursion depth to 1000. We can override this limit but it's usually not a good idea!
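As an aside, Python's standard library can do the caching for us. The sketch below uses functools.lru_cache, which is not what the code above does by hand but achieves the same effect; fib_cached is a name I've chosen to keep it separate from fib().

```python
from functools import lru_cache


@lru_cache(maxsize=None)  # maxsize=None means the cache never evicts entries
def fib_cached(n):
    if n <= 1:
        return 1
    return fib_cached(n - 1) + fib_cached(n - 2)


print(fib_cached(30))  # 1346269
```

The function body is identical to the slow recursive version; the decorator remembers each result the first time it is computed. The recursion depth limit still applies.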

Let's look at our tree now, this time for fib(6).

We can see from the tree that there are fewer calls to the function even though we are a value higher. This is a significant performance increase. Even for fib(5) we have saved ourselves 6 function calls!
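To check the saving, we can count calls to the memoized version in the same way. Here fib_memo is the cached function from above with a counter added; the counter dictionary is only there for the illustration.

```python
counter = {"calls": 0}


def fib_memo(n, cache=None):
    counter["calls"] += 1  # count every call, including cache hits
    if cache is None:
        cache = {}
    if n <= 1:
        return 1
    if n not in cache:
        cache[n] = fib_memo(n - 1, cache) + fib_memo(n - 2, cache)
    return cache[n]


result = fib_memo(5)
print(result, counter["calls"])  # 8 9
```

That is 9 calls, against 15 for the version without a cache: the 6 calls we saved.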

16.3 - Quicksort part 1

In the following sections I'm going to look at a very important sorting algorithm: quicksort. Quicksort is a recursive sorting algorithm that employs a divide-and-conquer strategy. Divide-and-conquer is a problem solving strategy in which we continuously break down a problem into easier, more manageable sub-problems and solve them. Note that Python's built-in sorted() function does not use quicksort: in CPython it uses Timsort, a hybrid of merge sort and insertion sort.

In this section we'll look at sorting lists of numbers.

Since this is a divide-and-conquer algorithm we want to take a list of unsorted integers and split the problem down into two easier problems, and then break each of those down, and so on.

To achieve this I’ll first cover quicksort’s core operation: partitioning. It works as follows:

>>> A = [6, 3, 17, 11, 4, 44, 76, 23, 12, 30]
>>> partition(A, 0, len(A)-1)
7
>>> print(A)
[6, 3, 17, 11, 4, 23, 12, 30, 76, 44]

So what happened here and how does it work? We need to pick some number as our pivot. Our partition function takes 3 arguments: the list, the index of the first element, and the index of the pivot. What we are trying to achieve is that after we partition the list, everything to the left of the pivot is less than or equal to the pivot and everything to the right is greater than the pivot. For the partition seen above, 30 is our pivot. We'll always take our pivot to be the last element in the list. After the partition we see some elements have changed position, but everything to the left of 30 is less than it and everything to the right is greater than it.

So what does that mean for us? Well it means that 30 is now in its correct position in the list AND we now have two easier lists to sort. All of this is done in-place so we are not creating new lists.
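We can check both claims directly on the partitioned list from the example. This sketch only inspects the result; the variable names are my own.

```python
A = [6, 3, 17, 11, 4, 23, 12, 30, 76, 44]  # the list after the partition shown above
pivot = 30
q = A.index(pivot)  # position where the pivot ended up

left_ok = all(x <= pivot for x in A[:q])       # everything left of the pivot
right_ok = all(x > pivot for x in A[q + 1:])   # everything right of the pivot
in_final_position = sorted(A)[q] == pivot      # same index as in the fully sorted list

print(q, left_ok, right_ok, in_final_position)  # 7 True True True
```

Neither side is sorted yet, but the pivot will never need to move again.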

Let's look at the code:

def partition(A, p, r):
    q = j = p
    while j < r:
        if A[j] <= A[r]:
            A[q], A[j] = A[j], A[q]
            q += 1
        j += 1
    A[q], A[r] = A[r], A[q]
    return q

The return q at the end isn’t necessary for our partition but it is essential for sorting the entire list. The above code works its way across the list A and maintains indices p, q, j, r.

p is fixed and is the index of the first element in the list. r is the index of the pivot, which is the last element in the list. Using Python slice notation (where the end index is excluded), the elements in A[p:q] are known to be less than or equal to the pivot and the elements in A[q:j] are known to be greater than the pivot. The only indices that change are q and j. At each step we compare A[j] with A[r]. If A[j] is greater than the pivot it is already in the correct region, so we just increment j and move to the next element. If A[j] is less than or equal to A[r] we swap A[q] with A[j]. After this swap, we increment q, thus extending the range of elements known to be less than or equal to the pivot. We also increment j to move to the next element to be processed.

I encourage you to work through an example with a small list on paper to make this clearer.
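If you want something to compare your paper working against, here is a sketch of the same partition with a print after every step. partition_traced is my own name for it, and the two assert lines check the ranges described above (elements up to q are at most the pivot, elements from q up to j are greater).

```python
def partition_traced(A, p, r):
    q = j = p
    while j < r:
        if A[j] <= A[r]:
            A[q], A[j] = A[j], A[q]
            q += 1
        j += 1
        # Check the two ranges after every step.
        assert all(x <= A[r] for x in A[p:q])
        assert all(x > A[r] for x in A[q:j])
        print(f"j={j} q={q} {A}")
    A[q], A[r] = A[r], A[q]  # drop the pivot between the two ranges
    print(f"pivot placed at index {q}: {A}")
    return q


small = [7, 2, 9, 4, 5]
index = partition_traced(small, 0, len(small) - 1)
```

With the pivot 5, the list ends up as [2, 4, 5, 7, 9] and the function returns 2.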

16.4 - Quicksort part 2

Now onto the quicksort part. Remember it is a recursive algorithm, so it will keep calling partition() until there is nothing left to partition and all elements are in their correct positions. After the first partition we partition again, once on the list of elements to the left of the pivot and once on the list of elements to the right of the pivot.

Let's look at the code:

def quicksort(A, p, r):
    if r <= p:               # If r <= p then our list is sorted
        return
    q = partition(A, p, r)   # partition incoming list
    quicksort(A, p, q-1)     # call quicksort again on everything to left of pivot
    quicksort(A, q+1, r)     # call quicksort again on everything to right of pivot
    return A

It's that simple. All we do here is check if the index of the pivot is less than or equal to the index of the start of the list we want to partition. If it is, we return, as whatever list was passed has at most one element and does not need to be partitioned any further.

Otherwise, we partition the list A, and call quicksort again on the two new sub lists.

On line 3, we return without specifying anything to return. This is because the quicksort function is a procedure: it sorts the list in place rather than computing a new value.
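Putting the two pieces together, a complete run might look like the sketch below. The quicksort_list wrapper is my own addition so that you don't have to pass the indices by hand, and this version of quicksort returns nothing at all, relying entirely on the in-place sort.

```python
def partition(A, p, r):
    q = j = p
    while j < r:
        if A[j] <= A[r]:
            A[q], A[j] = A[j], A[q]
            q += 1
        j += 1
    A[q], A[r] = A[r], A[q]
    return q


def quicksort(A, p, r):
    if r <= p:  # zero or one element: nothing to do
        return
    q = partition(A, p, r)
    quicksort(A, p, q - 1)
    quicksort(A, q + 1, r)


def quicksort_list(A):
    """Hypothetical convenience wrapper: sort the whole list in place."""
    quicksort(A, 0, len(A) - 1)
    return A


numbers = [6, 3, 17, 11, 4, 44, 76, 23, 12, 30]
quicksort_list(numbers)
print(numbers)  # [3, 4, 6, 11, 12, 17, 23, 30, 44, 76]
```

Because the sort happens in place, numbers itself is changed; no new list is created.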

Quicksort, as written here with the last element always chosen as the pivot, works best on large lists that are completely scrambled. It has really bad performance on lists that are almost sorted. In Big-O notation, the best case (scrambled) is O(n log(n)) and the worst case (almost or completely ordered list) is O(n^2).
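To see the difference, we can count how many comparisons against the pivot are made. The sketch below folds partition into one function, quicksort_count (my own name), purely so it can return that count. The list size of 200 and the fixed random seed are arbitrary choices for the demonstration.

```python
import random


def quicksort_count(A, p, r):
    """Sort A[p..r] in place and return the number of comparisons made."""
    if r <= p:
        return 0
    comparisons = 0
    q = j = p
    while j < r:
        comparisons += 1  # one comparison of A[j] against the pivot
        if A[j] <= A[r]:
            A[q], A[j] = A[j], A[q]
            q += 1
        j += 1
    A[q], A[r] = A[r], A[q]
    comparisons += quicksort_count(A, p, q - 1)
    comparisons += quicksort_count(A, q + 1, r)
    return comparisons


n = 200
ordered = list(range(n))
shuffled = ordered[:]
random.Random(0).shuffle(shuffled)  # fixed seed so the run is repeatable

sorted_cost = quicksort_count(ordered, 0, n - 1)
shuffled_cost = quicksort_count(shuffled, 0, n - 1)
print(sorted_cost)    # 19900, which is n * (n - 1) / 2
print(shuffled_cost)  # far fewer; the exact number depends on the shuffle
```

On the already sorted list every partition peels off just one element, so the work adds up to n * (n - 1) / 2 comparisons and the recursion goes n levels deep. That depth is also why the sketch keeps n well under Python's recursion limit.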

Again, I encourage you to try this on paper with a simple list. It will help clarify what is going on.

16.5 - Exercises

Important: These exercises are quite difficult! Make use of problem solving techniques to help you arrive at a solution. Sketching out what's happening on paper is a good way to help you out!

Some of these questions are the kind of questions you can expect if you get a recursion question when interviewing at companies such as Google, Facebook or Amazon.

Question 1

What is:

Question 2

Write a recursive function that adds up all the numbers from 1 to 100.

Question 3

Write a recursive function that takes an integer as an argument and returns whether or not that integer is a power of 2. Your function should return True or False.

Question 4

Write a recursive function that takes a string as an argument and returns whether or not that string is a palindrome. Your function should return True or False.

Question 5

Write a recursive function that takes a list of integers as an argument and returns the maximum of that list.

Hint: Thinking recursively, the maximum is either the first element of the list or the maximum of the remaining list!
