Deriving Tail Recursive Fibonacci

Introduction

One of the cornerstones of Functional Programming is the use of structural recursion: writing functions that are recursive (they call themselves) and that model the recursion after the structure of the data they are operating on.

For example, given the standard Cons list:

sealed trait List[+A]
case class Cons[+A](head: A, tail: List[A]) extends List[A]
case object Nil extends List[Nothing]

we can write a length function using structural recursion:

def length[A](list: List[A]): Int =
  list match {
    case Nil => 0
    case Cons(_, tail) => 1 + length(tail)
  }

Most introductory material will then explain that a function is tail-recursive if and only if the recursive call is in tail position, or in other words if and only if it calls itself recursively as its final action.

Note that looks can be deceiving: although visually it looks like the recursive call in our length function above is in tail position, it is not, as we can see when we expand the execution trace of its recursive case:

case Cons(_, tail) => 1 + length(tail)
// is evaluated like:
val a = 1
val b = length(tail) // not in tail position!
a + b

Instead, a tail-recursive length looks like this:

def length[A](list: List[A], acc: Int = 0): Int =
  list match {
    case Nil => acc
    case Cons(_, tail) => length(tail, acc + 1)
  }

Tail-recursive functions tend to be less obvious for most people than their structural counterpart, so you might be wondering whether learning how to write them is worth the effort.

Well, there are at least a couple of good, practical reasons:

  1. The Scala compiler turns tail-recursive functions into loops, so they run in constant stack space without incurring a StackOverflowError if the recursion is too deep.
  2. Some functions, when written with structural recursion, exhibit exponential complexity and are therefore prohibitively slow, and they can be made to run in linear time when written with tail recursion.
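
To make the first point concrete, here is a minimal sketch that repeats the Cons list from above and measures a list of one million elements. The helper name go, the @tailrec annotation and the list size are my own choices for illustration; the annotation simply asks the compiler to reject the function if the recursive call is not in tail position.

```scala
import scala.annotation.tailrec

sealed trait List[+A]
case class Cons[+A](head: A, tail: List[A]) extends List[A]
case object Nil extends List[Nothing]

// Tail-recursive length: the accumulator is hidden inside a local helper.
def length[A](list: List[A]): Int = {
  @tailrec
  def go(rest: List[A], acc: Int): Int =
    rest match {
      case Nil => acc
      case Cons(_, tail) => go(tail, acc + 1)
    }
  go(list, 0)
}

// Build a long list without recursion, by prepending inside a fold.
val big: List[Int] =
  (1 to 1000000).foldLeft(Nil: List[Int])((acc, i) => Cons(i, acc))

val bigLength = length(big)
```

The structural length from earlier would most likely overflow the stack on a list this long with default JVM settings, which is why it is not called here.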

…but I have to be honest, if those were the only reasons, I probably wouldn’t bother writing a post about it.

Instead, I believe the main reason tail recursion is worth learning is that the shape of tail-recursive functions forces us to focus on the essential state needed to perform a certain computation, and the usefulness of identifying essential state extends far beyond just Functional Programming. For example, it is crucial in my day job when dealing with concurrent and distributed algorithms.

This post is divided into two parts: Part I will show a recipe to derive tail-recursive functions via algebraic manipulation, which is very useful if you know the recursive structure of your computation, but struggle to express it in tail-recursive terms. Part II will instead show an alternative method which focuses on state, and gets us closer to the skillset needed to model imperative, and ultimately concurrent and distributed computations.

Naive Fibonacci

The running example for this post will be a function to compute the n-th number of the Fibonacci sequence:

0, 1, 1, 2, 3, 5, 8, 13, 21, …

which is described by the following recurrence relation:

F(0) = 0
F(1) = 1
F(n) = F(n - 1) + F(n - 2)

We can easily translate that into Scala via structural recursion, noting that it has two base cases, and that we will pretend that Scala’s Int cannot be negative:

def fib(n: Int): Int =
  n match {
    case 0 => 0
    case 1 => 1
    case n => fib(n - 1) + fib(n - 2)
  }

Fibonacci is often used as a motivating example to learn tail recursion because the version we’ve just written is exponential, and starts getting pretty slow as we increase n.
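
To get a feel for how quickly the work grows, we can count how many calls the naive version makes. This is only a sketch: callCount is a name I made up, and it mirrors the shape of the naive fib instead of instrumenting it.

```scala
// Number of calls the naive fib makes for input n: one for the call itself,
// plus the calls made by its two recursive branches.
def callCount(n: Int): Long =
  if (n <= 1) 1L
  else 1L + callCount(n - 1) + callCount(n - 2)

val callsFor10 = callCount(10) // 177
val callsFor20 = callCount(20) // 21891
val callsFor30 = callCount(30) // 2692537
```

Adding 10 to n multiplies the number of calls by a bit more than a hundred, which is the exponential growth we want to get rid of.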

So, tail recursion to the rescue!

Part I: Algebraic Derivation

Let’s start from the skeleton of the naive Fibonacci:

def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
   def go(): Int = ???
   ???
}

As you can see, we have left the recursive case unspecified, and instead introduced a helper function called go, which is where all the actual recursion is going to happen. This is a very common style since tail-recursive functions tend to need additional parameters, which are hidden by the helper function.

In particular, each component of the recursive definition becomes a parameter, so with the recurrence relation:

F(n) = F(n - 1) + F(n - 2)

we will have parameters for n, F(n), F(n - 1), and F(n - 2).

We can find appropriate names for these parameters by putting them on a number line in order:

         F(n - 2)        F(n - 1)          F(n)            n
-------------|---------------|---------------|-------------> 
        secondLast         last          current        counter

which means our function becomes:

def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
   def go(secondLast: Int, last: Int, current: Int, counter: Int): Int = ???
   ???
}

In order to call go, we need to find appropriate initial values for its arguments: the initial value for counter is 2, because iterations 0 and 1 are already covered by the base cases. Then, the other params are given by the Fibonacci series up to iteration 2:

0, 1, 1, …

so:

def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
   def go(secondLast: Int, last: Int, current: Int, counter: Int): Int = ???
   go(secondLast = 0, last = 1, current = 1, counter = 2)
}

Now let’s write go. Our fib function has to return the value of the Fibonacci sequence at iteration n, or in other words, the current value if counter is equal to n.

If counter is not at n, then we advance: the old last becomes the new secondLast, the old current becomes the new last, and we compute the next value of current with the definition of Fibonacci, which says the current Fibonacci number is equal to the sum of the last and secondLast Fibonacci numbers.

Note that because we have replaced recursive components with arguments to go, we can compute the next value of current without any recursion. We only need to recur in tail position by passing the updated arguments to go:

def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
   def go(secondLast: Int, last: Int, current: Int, counter: Int): Int =
     if (counter == n) current
     else {
       val counterNext = counter + 1
       val secondLastNext = last // we move to the right
       val lastNext = current // we move to the right
       val currentNext = lastNext + secondLastNext // definition of Fibonacci
       go(
        secondLast = secondLastNext,
        last = lastNext,
        current = currentNext,
        counter = counterNext
       )
     }

   go(secondLast = 0, last = 1, current = 1, counter = 2)
}

and we’re done! We can apply several transformations to considerably simplify this function, but before we do that, let’s recap the two ideas we’ve used to write tail-recursive functions:

  • Every component of the recursive definition becomes an argument to a recursive helper.
  • At the end of the recursion, one of the arguments will hold the final result.

One final observation is that the recursion is no longer structural, i.e. the fact that the current Fibonacci number is defined recursively no longer corresponds to recursion in code. Instead, the logical recursion corresponds to state updates, and the recursion in code is only used to iterate our state transformations until we reach the desired results. This is the reason why the recursive call can be in tail position, but I dare say the link between recursive definitions and state is a far more interesting aspect of this type of function than the use of tail recursion in itself.
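
One way to see these state updates is to record the arguments of every call to go. The sketch below does that with an extra accumulator; fibTrace and seen are hypothetical names added for illustration, and the require reflects the assumption that the helper only runs in the recursive case.

```scala
import scala.annotation.tailrec

// Each recorded tuple is (secondLast, last, current, counter).
def fibTrace(n: Int): Vector[(Int, Int, Int, Int)] = {
  require(n >= 2, "the helper only runs in the recursive case")

  @tailrec
  def go(
      secondLast: Int,
      last: Int,
      current: Int,
      counter: Int,
      seen: Vector[(Int, Int, Int, Int)]
  ): Vector[(Int, Int, Int, Int)] = {
    // Record the state of this iteration before deciding whether to stop.
    val seenNext = seen :+ ((secondLast, last, current, counter))
    if (counter == n) seenNext
    else go(last, current, current + last, counter + 1, seenNext)
  }

  go(0, 1, 1, 2, Vector.empty)
}

val trace = fibTrace(5)
// Vector((0,1,1,2), (1,1,2,3), (1,2,3,4), (2,3,5,5))
```

Reading the trace row by row, each value shifts one slot to the left on the number line while counter moves forward, and the current slot of the last row holds the answer.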

Simplification

Ok, now onto simplifying our implementation: the first transformation we can apply is inlining, i.e. replacing each name with its definition. This is safe to do because even though the algorithm is conceptually evolving variables, it’s expressed as immutable transformations for which inlining can always be done without changing behaviour. As an example:

go(..., currentNext, ...)
// but:
currentNext = lastNext + secondLastNext
// so:
go(..., lastNext + secondLastNext, ...)
// but:
lastNext = current
secondLastNext = last
// so:
go(..., current + last, ...)

Let’s inline all the definitions in go:

def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
   def go(secondLast: Int, last: Int, current: Int, counter: Int): Int =
     if (counter == n) current
     else 
       go(
        secondLast = last,
        last = current,
        current = current + last,
        counter = counter + 1,
       )
   
   go(secondLast = 0, last = 1, current = 1, counter = 2)
}

but now we notice that secondLast is redundant, because it’s updated but never actually used to compute current: current is updated via current + last, and last doesn’t depend on secondLast.
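
If the argument feels too hand-wavy, a quick sanity check is to run the helper with and without the parameter and compare the results. This is a sketch, not a proof: the two function names and the range 0 to 30 are arbitrary choices of mine.

```scala
import scala.annotation.tailrec

// The inlined version, which still threads secondLast through every call.
def fibWithSecondLast(n: Int): Int = {
  @tailrec
  def go(secondLast: Int, last: Int, current: Int, counter: Int): Int =
    if (counter == n) current
    else go(last, current, current + last, counter + 1)

  n match {
    case 0 => 0
    case 1 => 1
    case _ => go(0, 1, 1, 2)
  }
}

// The same helper with the parameter removed.
def fibWithoutSecondLast(n: Int): Int = {
  @tailrec
  def go(last: Int, current: Int, counter: Int): Int =
    if (counter == n) current
    else go(current, current + last, counter + 1)

  n match {
    case 0 => 0
    case 1 => 1
    case _ => go(1, 1, 2)
  }
}

val sameResults =
  (0 to 30).forall(n => fibWithSecondLast(n) == fibWithoutSecondLast(n))
```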

So we can eliminate it to get to:

def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
   def go(last: Int, current: Int, counter: Int): Int =
     if (counter == n) current
     else 
       go(
        last = current,
        current = current + last,
        counter = counter + 1
       )

   go(last = 1, current = 1, counter = 2)
}

Now let’s eliminate the named arguments, and hoist go to the top. It’s tail recursive, and we can have the compiler confirm it by adding the @tailrec annotation:

import annotation.tailrec

def fib(n: Int) = {
  @tailrec
  def go(last: Int, current: Int, counter: Int): Int =
    if (counter == n) current 
    else go(current, current + last, counter + 1)

  n match {
    case 0 => 0
    case 1 => 1
    case n => go(1, 1, 2)
  }
}

Finally, we can see that the pattern matching has redundant logic in the first two cases, and doesn’t use the pattern in the third, so we can replace it with an if. This is our final version:

import annotation.tailrec

def fib(n: Int) = {
  @tailrec
  def go(last: Int, current: Int, counter: Int): Int =
    if (counter == n) current 
    else go(current, current + last, counter + 1)

  if (n <= 1) n else go(1, 1, 2)
}
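
As a quick check, we can compare this version against the recurrence relation for small inputs. The sketch below also shows a caveat the derivation does not cover: the result is an Int, so it wraps around once the sequence exceeds Int.MaxValue. The name fibNaive and the tested range are my own choices.

```scala
import scala.annotation.tailrec

// The final version from above, repeated so the example is self-contained.
def fib(n: Int): Int = {
  @tailrec
  def go(last: Int, current: Int, counter: Int): Int =
    if (counter == n) current
    else go(current, current + last, counter + 1)

  if (n <= 1) n else go(1, 1, 2)
}

// Direct transcription of the recurrence relation, used as a reference.
def fibNaive(n: Int): Int =
  if (n <= 1) n else fibNaive(n - 1) + fibNaive(n - 2)

val agrees = (0 to 20).forall(n => fib(n) == fibNaive(n))

val largestThatFits = fib(46) // 1836311903
val wrappedAround = fib(47)   // negative: the true value 2971215073 does not fit in an Int
```

If larger inputs matter, one option is to change the result type to Long or BigInt; the shape of go stays the same.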

Part II: Operational Derivation

The derivation above used a lot of equational thinking, but often with tail recursion we can adopt a more operational mindset.

In fact, tail recursion can be understood as a translation from mutable state algorithms, one where imperative thinking is more explicit than in most imperative languages, in that for each variable x you distinguish x from its updated version x' (due to immutability), and you loop via a restricted GOTO (the recursive call) instead of using while.
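
To make that correspondence concrete, here is a sketch of the mutable loop that the final version from Part I translates back into. The name fibLoop is made up, and the Next values play the role of the primed variables x'.

```scala
def fibLoop(n: Int): Int =
  if (n <= 1) n
  else {
    // Mutable counterparts of the parameters of go.
    var last = 1
    var current = 1
    var counter = 2
    while (counter != n) {
      // Compute the primed versions from the old state first...
      val lastNext = current
      val currentNext = current + last
      // ...then overwrite the state, which is what the recursive call does.
      last = lastNext
      current = currentNext
      counter += 1
    }
    current
  }

val tenth = fibLoop(10) // 55
```

Computing lastNext and currentNext before assigning anything matters here: updating last in place first would silently change the value used to compute current, a mistake the immutable version cannot make.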

So, let’s look at the recurrence relation again:

F(0) = 0
F(1) = 1
F(n) = F(n - 1) + F(n - 2)

This time, we start by recognising that the first two cases can be done with an if, and then we know there is going to be some iteration in the recursive case, which we can represent with a helper. At this point, the only information we have about the helper is its return type:

def fib(n: Int) = {
  def go(): Int = go()
  if (n <= 1) n else go()
}

The guiding principle when deriving tail-recursive functions via operational thinking is figuring out which state you need to keep track of, and adding each piece of state as a parameter. Tail recursive helpers have to return a result which is updated during the iteration, so we can start by keeping track of that.

def fib(n: Int) = {
  def go(result: Int): Int = go(result)
  if (n <= 1) n else go(???)
}

To call go, we need to figure out the initial value of result. We know that the if returns the result of fib(n) when n == 0 and n == 1, so the initial value of result is fib(n) when n == 2, which is 1, as per:

0, 1, 1, …

def fib(n: Int) = {
   def go(result: Int): Int = go(result)
   if (n <= 1) n else go(1)
}

Next, we have to figure out when it is possible to return result, i.e. a termination condition. Sometimes this can be done as a predicate on result without any additional state, but in this case the result has to be returned once we reach the n-th iteration, which means we have to add a counter: Int parameter to keep track of which iteration we’re in.
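
For contrast, here is a small, unrelated example where the termination condition is a predicate on result alone, so no counter is needed. The function nextPowerOfTwo is hypothetical and only serves to illustrate the point; the upper bound in require is there because doubling an Int past 2^30 overflows.

```scala
import scala.annotation.tailrec

// Smallest power of two that is greater than or equal to n.
def nextPowerOfTwo(n: Int): Int = {
  require(n <= (1 << 30), "the result must fit in an Int")

  @tailrec
  def go(result: Int): Int =
    if (result >= n) result // the predicate only looks at result
    else go(result * 2)

  go(1)
}

val rounded = nextPowerOfTwo(1000) // 1024
```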

Because the if covers iterations 0 and 1, the initial value of counter will be 2:

def fib(n: Int) = {
  def go(result: Int, counter: Int): Int =
    if (counter == n) result
    else go(result, counter)
  if (n <= 1) n else go(result = 1, counter = 2)
}

Now that we have some state with initial values, we have to figure out how to update it before recurring:

def fib(n: Int) = {
  def go(result: Int, counter: Int): Int =
    if (counter == n) result
    else {
      val counterNext = counter + 1
      val resultNext = ???
      go(resultNext, counterNext)
    }
  if (n <= 1) n else go(result = 1, counter = 2)
}

Updating counter is trivial, but we don’t know what resultNext should be. Well, according to the definition of Fibonacci, it’s the sum of fib(n - 1) and fib(n - 2), and we already have fib(n - 1): it’s result!

To see why, consider that resultNext is the result, i.e. the fib(n), of the next iteration, which means that fib(n - 1) is the previous value from the point of view of the next iteration, i.e. the result of the current iteration. Let’s update the code:

def fib(n: Int) = {
  def go(result: Int, counter: Int): Int =
    if (counter == n) result
    else {
      val counterNext = counter + 1
      val resultNext = result + ???
      go(resultNext, counterNext)
    }
  if (n <= 1) n else go(result = 1, counter = 2)
}

However, we simply cannot compute fib(n - 2), which means we are missing a parameter to track it. Let’s call it last: Int. Understanding why this is the last value can be a bit tricky, but it’s the same idea as before: we need fib(n - 2) in the next iteration, which means we need fib(n - 1) in this iteration, i.e. the last value.

Now, last is going to need an initial value, and it’s also going to need to be updated on each iteration before the recursive call. Let’s start with the initial value, which is straightforward: we need fib(n - 1) when counter == 2, which is 1. The updated lastNext value represents, once again, the previous result from the point of view of the next iteration, i.e. the result of the current iteration.

So we have:

def fib(n: Int) = {
  def go(result: Int, counter: Int, last: Int): Int =
    if (counter == n) result
    else {
      val counterNext = counter + 1
      val resultNext = result + last
      val lastNext = result
      go(resultNext, counterNext, lastNext)
    }
  if (n <= 1) n else go(result = 1, counter = 2, last = 1)
}

And let’s apply the same refactoring as in Part I:

import annotation.tailrec

def fib(n: Int) = {
  @tailrec
  def go(result: Int, counter: Int, last: Int): Int =
    if (counter == n) result
    else go(result + last, counter + 1, result)

  if (n <= 1) n else go(1, 2, 1)
}
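
The two derivations end up with the same state, just with the parameters in a different order. As a sanity check, the sketch below puts both final versions side by side and compares them; the names fibAlgebraic and fibOperational, and the range 0 to 46, are my own choices.

```scala
import scala.annotation.tailrec

// Final version from Part I: go(last, current, counter).
def fibAlgebraic(n: Int): Int = {
  @tailrec
  def go(last: Int, current: Int, counter: Int): Int =
    if (counter == n) current
    else go(current, current + last, counter + 1)

  if (n <= 1) n else go(1, 1, 2)
}

// Final version from Part II: go(result, counter, last).
def fibOperational(n: Int): Int = {
  @tailrec
  def go(result: Int, counter: Int, last: Int): Int =
    if (counter == n) result
    else go(result + last, counter + 1, result)

  if (n <= 1) n else go(1, 2, 1)
}

val bothAgree =
  (0 to 46).forall(n => fibAlgebraic(n) == fibOperational(n))
```

Here result in Part II plays the role of current in Part I, which is why the two helpers compute the same thing.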

Conclusion

Although most didactic material covers the idea of tail recursion, actually writing tail-recursive functions is often left as an exercise for the reader.

This is a shame because, far from just being an optimisation, tail recursion is actually great at expressing tricky stateful logic, and really teaches us to think about state methodically.

So, tail recurse more, and see you next time!