Introduction
One of the cornerstones of Functional Programming is the use of structural recursion: writing functions that are recursive (they call themselves) and that model the recursion after the structure of the data they are operating on.
For example, given the standard Cons list:
sealed trait List[+A]
case class Cons[+A](head: A, tail: List[A]) extends List[A]
case object Nil extends List[Nothing]
we can write a length function using structural recursion:
def length[A](list: List[A]): Int =
  list match {
    case Nil => 0
    case Cons(_, tail) => 1 + length(tail)
  }
Most introductory material will then explain that a function is tail-recursive if and only if every recursive call is in tail position, or in other words if and only if calling itself is its final action.
Note that looks can be deceiving: although the recursive call in our length function above looks like it is in tail position, it is not, as we can see when we expand the evaluation of its recursive case:
case Cons(_, tail) => 1 + length(tail)
// is evaluated like:
val a = 1
val b = length(tail) // not in tail position!
a + b
Instead, a tail-recursive length looks like this:
def length[A](list: List[A], acc: Int = 0): Int =
  list match {
    case Nil => acc
    case Cons(_, tail) => length(tail, acc + 1)
  }
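To see the two definitions side by side, here is a small self-contained sketch. The wrapper object LengthExample and the names lengthStructural and lengthTailRec are made up for this illustration (the post calls both versions length). The @tailrec annotation asks the compiler to reject a function whose recursive call is not in tail position.

```scala
import scala.annotation.tailrec

object LengthExample {
  sealed trait List[+A]
  case class Cons[+A](head: A, tail: List[A]) extends List[A]
  case object Nil extends List[Nothing]

  // Structural version: the addition happens after the recursive call returns.
  def lengthStructural[A](list: List[A]): Int =
    list match {
      case Nil           => 0
      case Cons(_, tail) => 1 + lengthStructural(tail)
    }

  // Accumulator version: the recursive call is the last action.
  @tailrec
  def lengthTailRec[A](list: List[A], acc: Int = 0): Int =
    list match {
      case Nil           => acc
      case Cons(_, tail) => lengthTailRec(tail, acc + 1)
    }

  val sample: List[Int] = Cons(1, Cons(2, Cons(3, Nil)))

  // lengthStructural(sample) unfolds as 1 + (1 + (1 + 0))
  // lengthTailRec(sample) steps through acc = 0, 1, 2, 3
}
```

Both functions return 3 for sample; the difference is only in where the pending work lives (on the stack versus in acc).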
Tail-recursive functions tend to be less obvious for most people than their structural counterparts, so you might be wondering whether learning how to write them is worth the effort.
Well, there are at least a couple of good, practical reasons:
- The Scala compiler turns tail-recursive functions into loops that run in constant stack space, so they don’t throw a StackOverflowError when the recursion is too deep.
- Some functions, when written with structural recursion, exhibit exponential complexity and are therefore prohibitively slow, and they can be made to run in linear time when written with tail recursion.
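As a rough sketch of the first point, here is a tail-recursive function running one million iterations. The object StackSafetyExample and the helper sumTo are hypothetical names used only for this illustration.

```scala
import scala.annotation.tailrec

object StackSafetyExample {
  // Hypothetical helper, not from the post: sums 1 + 2 + ... + n.
  // The recursive call is the last action, so @tailrec is accepted and
  // the compiler turns the recursion into a loop.
  @tailrec
  def sumTo(n: Long, acc: Long = 0L): Long =
    if (n <= 0) acc
    else sumTo(n - 1, acc + n)

  // One million iterations, with no growth in stack depth.
  val big: Long = sumTo(1000000L)
}
```

A version written as n + sumTo(n - 1) would need one stack frame per step, and at this depth it would typically exhaust the default JVM stack.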
…but I have to be honest, if those were the only reasons, I probably wouldn’t bother writing a post about it.
Instead, I believe the main reason tail recursion is worth learning is that the shape of tail-recursive functions forces us to focus on the essential state needed to perform a certain computation, and the usefulness of identifying essential state extends far beyond just Functional Programming. For example, it is crucial in my day job when dealing with concurrent and distributed algorithms.
This post is divided into two parts: Part I will show a recipe to derive tail-recursive functions via algebraic manipulation, which is very useful if you know the recursive structure of your computation, but struggle to express it in tail-recursive terms. Part II will instead show an alternative method which focuses on state, and gets us closer to the skillset needed to model imperative, and ultimately concurrent and distributed computations.
Naive Fibonacci
The running example for this post will be a function to compute the n-th number of the Fibonacci sequence:
0, 1, 1, 2, 3, 5, 8, 13, 21, …
which is described by the following recurrence relation:
F(0) = 0
F(1) = 1
F(n) = F(n - 1) + F(n - 2)
We can easily translate that into Scala via structural recursion, noting that it has two base cases, and that we will pretend that Scala’s Int cannot be negative:
def fib(n: Int): Int =
  n match {
    case 0 => 0
    case 1 => 1
    case n => fib(n - 1) + fib(n - 2)
  }
Fibonacci is often used as a motivating example to learn tail recursion because the version we’ve just written takes exponential time in n, and starts getting pretty slow as we increase n.
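To get a feel for that growth, here is a small sketch that mirrors the naive definition but also counts how many calls it makes. The object NaiveFibCost and the function fibCounted are hypothetical names introduced for this illustration.

```scala
object NaiveFibCost {
  // Returns (fib(n), number of calls made), mirroring the naive definition.
  // Like the naive fib, it assumes n is not negative.
  def fibCounted(n: Int): (Int, Long) =
    n match {
      case 0 => (0, 1L)
      case 1 => (1, 1L)
      case _ =>
        val (a, callsA) = fibCounted(n - 1)
        val (b, callsB) = fibCounted(n - 2)
        (a + b, callsA + callsB + 1L)
    }
}
```

fibCounted(10) makes 177 calls, while fibCounted(20) makes 21891: adding 10 to n multiplies the work by more than a hundred.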
So, tail recursion to the rescue!
Part I: Algebraic Derivation
Let’s start from the naive Fibonacci:
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(): Int = ???
    ???
}
As you can see, we have left the recursive case unspecified, and instead introduced a helper function called go, which is where all the actual recursion is going to happen.
This is a very common style since tail-recursive functions tend to need additional parameters, which are hidden by the helper function.
In particular, each component of the recursive definition becomes a parameter, so with the recurrence relation:
F(n) = F(n - 1) + F(n - 2)
we will have parameters for n, F(n), F(n - 1), and F(n - 2).
We can find appropriate names for these parameters by putting them on a number line in order:
  F(n - 2)       F(n - 1)          F(n)             n
-------------|---------------|---------------|------------->
 secondLast        last           current        counter
which means our function becomes:
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(secondLast: Int, last: Int, current: Int, counter: Int): Int = ???
    ???
}
In order to call go, we need to find appropriate initial values for its arguments: the initial value for counter is 2, because iterations 0 and 1 are already covered by the base cases.
Then, the other parameters are given by the Fibonacci sequence up to iteration 2:
0, 1, 1, …
so:
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(secondLast: Int, last: Int, current: Int, counter: Int): Int = ???
    go(secondLast = 0, last = 1, current = 1, counter = 2)
}
Now let’s write go. Our fib function has to return the value of the Fibonacci sequence at iteration n, or in other words, the current value once counter is equal to n.
If counter is not at n, then we advance: the old last becomes the new secondLast, the old current becomes the new last, and we compute the next value of current with the definition of Fibonacci, which says that the current Fibonacci number is equal to the sum of the last and secondLast Fibonacci numbers.
Note that because we have replaced recursive components with arguments to go, we can compute the next value of current without any recursion. We only need to recur in tail position by passing the updated arguments to go:
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(secondLast: Int, last: Int, current: Int, counter: Int): Int =
      if (counter == n) current
      else {
        val counterNext = counter + 1
        val secondLastNext = last // we move to the right
        val lastNext = current // we move to the right
        val currentNext = lastNext + secondLastNext // definition of Fibonacci
        go(
          secondLast = secondLastNext,
          last = lastNext,
          current = currentNext,
          counter = counterNext
        )
      }
    go(secondLast = 0, last = 1, current = 1, counter = 2)
}
and we’re done! We can apply several transformations to considerably simplify this function, but before we do that, let’s recap the two ideas we’ve used to write tail-recursive functions:
- Every component of the recursive definition becomes an argument to a recursive helper.
- At the end of the recursion, one of the arguments will hold the final result.
One final observation is that the recursion is no longer structural, i.e. the fact that the current Fibonacci number is defined recursively no longer corresponds to recursion in code. Instead, the logical recursion corresponds to state updates, and the recursion in code is only used to iterate our state transformations until we reach the desired result. This is the reason why the recursive call can be in tail position, but I dare say the link between recursive definitions and state is a far more interesting aspect of this type of function than the use of tail recursion in itself.
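The state updates performed by go can be made visible with a small sketch that applies the same transitions but also records every state it visits. The object FibTrace, the State case class and the trace function are hypothetical names used only for this illustration.

```scala
import scala.annotation.tailrec

object FibTrace {
  // The four pieces of state threaded through `go`.
  final case class State(secondLast: Int, last: Int, current: Int, counter: Int)

  // Same transitions as `go`, but returns every state visited (oldest first).
  // Assumes n >= 2, since 0 and 1 are handled by the base cases of fib.
  def trace(n: Int): List[State] = {
    require(n >= 2, "n must be at least 2")

    @tailrec
    def go(s: State, seen: List[State]): List[State] =
      if (s.counter == n) (s :: seen).reverse
      else go(State(s.last, s.current, s.current + s.last, s.counter + 1), s :: seen)

    go(State(secondLast = 0, last = 1, current = 1, counter = 2), Nil)
  }
}
```

For n = 5 the recorded states are (0, 1, 1, 2), (1, 1, 2, 3), (1, 2, 3, 4) and (2, 3, 5, 5): each step shifts the window one position to the right on the number line, and the final current is the answer.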
Simplification
Ok, now onto simplifying our implementation: the first transformation we can apply is inlining, i.e. replacing each name with its definition. This is safe to do because even though the algorithm is conceptually evolving variables, it’s expressed as immutable transformations for which inlining can always be done without changing behaviour. As an example:
go(..., currentNext, ...)
// but:
currentNext = lastNext + secondLastNext
// so:
go(..., lastNext + secondLastNext, ...)
// but:
lastNext = current
secondLastNext = last
// so:
go(..., current + last, ...)
Let’s inline all the definitions in go:
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(secondLast: Int, last: Int, current: Int, counter: Int): Int =
      if (counter == n) current
      else
        go(
          secondLast = last,
          last = current,
          current = current + last,
          counter = counter + 1
        )
    go(secondLast = 0, last = 1, current = 1, counter = 2)
}
but now we notice that secondLast is redundant, because it’s updated but never actually used to compute current: current is updated via current + last, and last doesn’t depend on secondLast.
So we can eliminate it to get to:
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(last: Int, current: Int, counter: Int): Int =
      if (counter == n) current
      else
        go(
          last = current,
          current = current + last,
          counter = counter + 1
        )
    go(last = 1, current = 1, counter = 2)
}
Now let’s eliminate the named arguments, and hoist go to the top.
It’s tail-recursive, and we can have the compiler verify that by adding the @tailrec annotation:
import annotation.tailrec
def fib(n: Int) = {
  @tailrec
  def go(last: Int, current: Int, counter: Int): Int =
    if (counter == n) current
    else go(current, current + last, counter + 1)
  n match {
    case 0 => 0
    case 1 => 1
    case n => go(1, 1, 2)
  }
}
Finally, we can see that the first two cases of the pattern match both simply return n, and the third doesn’t use the pattern at all, so we can replace the match with an if. This is our final version:
import annotation.tailrec
def fib(n: Int) = {
  @tailrec
  def go(last: Int, current: Int, counter: Int): Int =
    if (counter == n) current
    else go(current, current + last, counter + 1)
  if (n <= 1) n else go(1, 1, 2)
}
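As a quick sanity check, this sketch compares the final version against the naive one for small inputs. The object FibCheck and the name fibNaive are hypothetical, introduced only so that both definitions can live side by side.

```scala
import scala.annotation.tailrec

object FibCheck {
  // Final tail-recursive version from above.
  def fib(n: Int): Int = {
    @tailrec
    def go(last: Int, current: Int, counter: Int): Int =
      if (counter == n) current
      else go(current, current + last, counter + 1)

    if (n <= 1) n else go(1, 1, 2)
  }

  // Naive structural version, used only as a reference for small n.
  def fibNaive(n: Int): Int =
    n match {
      case 0 => 0
      case 1 => 1
      case _ => fibNaive(n - 1) + fibNaive(n - 2)
    }

  // True when both versions agree on every n from 0 to 20.
  val agree: Boolean = (0 to 20).forall(n => fib(n) == fibNaive(n))
}
```

One caveat worth knowing: fib(46) is 1836311903, the largest Fibonacci number that fits in an Int. Past that, Int addition wraps around silently, so a wider type such as BigInt would be needed.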
Part II: Operational Derivation
The derivation above used a lot of equational thinking, but often with tail recursion we can adopt a more operational mindset.
In fact, tail recursion can be understood as a translation from mutable state algorithms, one where imperative thinking is more explicit than in most imperative languages, in that for each variable x you distinguish x from its updated version x' (due to immutability), and you loop via a restricted GOTO (the recursive call) instead of using while.
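To make the correspondence concrete, here is a sketch of the final version from Part I written as a mutable while loop. The object FibLoop and the function fibWhile are hypothetical names; the lastNext and currentNext values play the role of x' for each variable x.

```scala
object FibLoop {
  // Imperative version: mutable variables and a while loop.
  def fibWhile(n: Int): Int =
    if (n <= 1) n
    else {
      var last = 1
      var current = 1
      var counter = 2
      while (counter != n) {
        val lastNext = current           // last'
        val currentNext = current + last // current'
        last = lastNext
        current = currentNext
        counter += 1
      }
      current
    }
}
```

Each var corresponds to a parameter of the tail-recursive helper, the loop condition is the negation of the termination check, and one pass through the loop body corresponds to one recursive call.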
So, let’s look at the recurrence relation again:
F(0) = 0
F(1) = 1
F(n) = F(n - 1) + F(n - 2)
This time, we start by recognising that the first two cases can be done with an if, and then we know there is going to be some iteration in the recursive case, which we can represent with a helper.
At this point, the only information we have about the helper is its return type:
def fib(n: Int) = {
  def go(): Int = go()
  if (n <= 1) n else go()
}
The guiding principle when deriving tail-recursive functions via operational thinking is figuring out which state you need to keep track of, and adding each piece of state as a parameter. Tail-recursive helpers have to return a result which is updated during the iteration, so we can start by keeping track of that.
def fib(n: Int) = {
  def go(result: Int): Int = go(result)
  if (n <= 1) n else go(???)
}
To call go, we need to figure out the initial value of result. We know that the if returns the result of fib(n) when n == 0 and n == 1, so the initial value of result is fib(n) when n == 2, which is 1, as per:
0, 1, 1, …
def fib(n: Int) = {
  def go(result: Int): Int = go(result)
  if (n <= 1) n else go(1)
}
Next, we have to figure out when it is possible to return result, i.e. a termination condition. Sometimes this can be done as a predicate on result without any additional state, but in this case the result has to be returned once we reach the n-th iteration, which means we have to add a counter: Int parameter to keep track of which iteration we’re in.
Because the if covers iterations 0 and 1, the initial value of counter will be 2:
def fib(n: Int) = {
  def go(result: Int, counter: Int): Int =
    if (counter == n) result
    else go(result, counter)
  if (n <= 1) n else go(result = 1, counter = 2)
}
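As an aside, and for contrast with the counter we just added, here is a sketch of a case where the termination condition really is a predicate on result alone. The example (the smallest power of two that is at least n) and the names PredicateTermination and nextPowerOfTwo are hypothetical and unrelated to Fibonacci.

```scala
import scala.annotation.tailrec

object PredicateTermination {
  // Hypothetical example: smallest power of two that is >= n.
  // The bound on n keeps `result * 2` from overflowing Int.
  def nextPowerOfTwo(n: Int): Int = {
    require(n >= 1 && n <= (1 << 30), "n must be in [1, 2^30]")

    @tailrec
    def go(result: Int): Int =
      if (result >= n) result // termination depends on result only
      else go(result * 2)

    go(1)
  }
}
```

Here result is the only piece of state: no counter is needed, because looking at result is enough to decide when to stop.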
Now that we have some state with initial values, we have to figure out how to update it before recurring:
def fib(n: Int) = {
  def go(result: Int, counter: Int): Int =
    if (counter == n) result
    else {
      val counterNext = counter + 1
      val resultNext = ???
      go(resultNext, counterNext)
    }
  if (n <= 1) n else go(result = 1, counter = 2)
}
Updating counter is trivial, but we don’t know what resultNext should be. Well, according to the definition of Fibonacci, it’s the sum of fib(n - 1) and fib(n - 2), and we already have fib(n - 1): it’s result!
To see why, consider that resultNext is the result, i.e. the fib(n), of the next iteration, which means that fib(n - 1) is the previous value from the point of view of the next iteration, i.e. the result of the current iteration. Let’s update the code:
def fib(n: Int) = {
  def go(result: Int, counter: Int): Int =
    if (counter == n) result
    else {
      val counterNext = counter + 1
      val resultNext = result + ???
      go(resultNext, counterNext)
    }
  if (n <= 1) n else go(result = 1, counter = 2)
}
However, we simply cannot compute fib(n - 2), which means we are missing a parameter to track it. Let’s call it last: Int.
Understanding why this is the last value can be a bit tricky, but it’s the same idea as before: we need fib(n - 2) in the next iteration, which means we need fib(n - 1) in this iteration, i.e. the last value.
Now, last is going to need an initial value, and it’s also going to need to be updated on each iteration before the recursive call.
Let’s start with the initial value, which is straightforward: we need fib(n - 1) when counter == 2, which is 1.
The updated lastNext value represents, once again, the previous result from the point of view of the next iteration, i.e. the result of the current iteration.
So we have:
def fib(n: Int) = {
  def go(result: Int, counter: Int, last: Int): Int =
    if (counter == n) result
    else {
      val counterNext = counter + 1
      val resultNext = result + last
      val lastNext = result
      go(resultNext, counterNext, lastNext)
    }
  if (n <= 1) n else go(result = 1, counter = 2, last = 1)
}
And let’s apply the same refactoring as in Part I:
import annotation.tailrec
def fib(n: Int) = {
  @tailrec
  def go(result: Int, counter: Int, last: Int): Int =
    if (counter == n) result
    else go(result + last, counter + 1, result)
  if (n <= 1) n else go(1, 2, 1)
}
Conclusion
Although most didactic material covers the idea of tail recursion, actually writing tail-recursive functions is often left as an exercise to the reader.
This is a shame because, far from just being an optimisation, tail recursion is actually great at expressing tricky stateful logic, and really teaches us to think about state methodically.
So, tail recurse more, and see you next time!