## Introduction

One of the cornerstones of Functional Programming is the use of
*structural recursion*: writing functions that are *recursive* (they
call themselves) and that model the recursion after the structure of
the data they are operating on.

For example, given the standard `Cons` list:

```
sealed trait List[+A]
case class Cons[+A](head: A, tail: List[A]) extends List[A]
case object Nil extends List[Nothing]
```

we can write a `length` function using structural recursion:

```
def length[A](list: List[A]): Int =
  list match {
    case Nil => 0
    case Cons(_, tail) => 1 + length(tail)
  }
```

Most introductory material will then explain that a function is
*tail-recursive* if and only if its recursive call is in *tail
position*, or in other words if and only if it calls itself
recursively as its final action.

Note that looks can be deceiving: although *visually* it looks like
the recursive call in our `length` function above is in tail position,
it is not, as we can see when we expand the execution trace of its
recursive case:

```
case Cons(_, tail) => 1 + length(tail)
// is evaluated like:
val a = 1
val b = length(tail) // not in tail position!
a + b
```

Instead, a tail-recursive `length` looks like this:

```
def length[A](list: List[A], acc: Int = 0): Int =
  list match {
    case Nil => acc
    case Cons(_, tail) => length(tail, acc + 1)
  }
```

Tail-recursive functions tend to be less obvious for most people than their structural counterparts, so you might be wondering whether learning how to write them is worth the effort.

Well, there are at least a couple of good, practical reasons:

- The Scala compiler can run tail-recursive functions in constant
  stack space, without incurring a `StackOverflowError` if the
  recursion is too deep.
- Some functions, when written with structural recursion, exhibit
  exponential complexity and are therefore prohibitively slow, and
  they can be made to run in linear time when written with tail
  recursion.
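To make the first point concrete, here is a minimal sketch that puts the two `length` variants side by side. The object name `StackSafetySketch`, the names `lengthStructural` and `lengthTailRec`, and the `zeros` helper are made up for illustration; the list definition is repeated so that the snippet compiles on its own.

```scala
import scala.annotation.tailrec

object StackSafetySketch {
  sealed trait List[+A]
  case class Cons[+A](head: A, tail: List[A]) extends List[A]
  case object Nil extends List[Nothing]

  // Structural version: every call has to wait for the result of the
  // recursive call, so it needs one stack frame per element.
  def lengthStructural[A](list: List[A]): Int =
    list match {
      case Nil => 0
      case Cons(_, tail) => 1 + lengthStructural(tail)
    }

  // Tail-recursive version: @tailrec makes compilation fail if the
  // recursive call is not in tail position.
  @tailrec
  def lengthTailRec[A](list: List[A], acc: Int = 0): Int =
    list match {
      case Nil => acc
      case Cons(_, tail) => lengthTailRec(tail, acc + 1)
    }

  // Hypothetical helper: iteratively builds a list of `n` zeros.
  def zeros(n: Int): List[Int] =
    (1 to n).foldLeft(Nil: List[Int])((acc, _) => Cons(0, acc))
}
```

On a list built with `zeros(1000000)`, `lengthTailRec` returns `1000000`, whereas `lengthStructural` will most likely blow the stack with default JVM settings; the exact depth at which that happens depends on the configured stack size.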

…but I have to be honest, if those were the only reasons, I probably wouldn’t bother writing a post about it.

Instead, I believe the main reason tail recursion is worth learning is
because the shape of tail-recursive functions forces us to focus on
the *essential state* needed to perform a certain computation, and
the usefulness of identifying essential state extends far beyond just
Functional Programming. For example, it is crucial in my day job when
dealing with concurrent and distributed algorithms.

This post is divided into two parts: Part I will show a recipe to
derive tail-recursive functions via *algebraic manipulation*, which is
very useful if you know the recursive structure of your computation,
but struggle to express it in tail-recursive terms. Part II will
instead show an alternative method which focuses on *state*, and gets
us closer to the skillset needed to model imperative, and ultimately
concurrent and distributed computations.

## Naive Fibonacci

The running example for this post will be a function to compute the
`n-th` number of the Fibonacci sequence:

0, 1, 1, 2, 3, 5, 8, 13, 21, …

which is described by the following *recurrence relation*:

F(0) = 0

F(1) = 1

F(n) = F(n - 1) + F(n - 2)

We can easily translate that into Scala via structural recursion,
noting that it has two base cases, and that we will pretend that
Scala’s `Int` cannot be negative:

```
def fib(n: Int): Int =
  n match {
    case 0 => 0
    case 1 => 1
    case n => fib(n - 1) + fib(n - 2)
  }
```

Fibonacci is often used as a motivating example to learn
tail recursion because the version we’ve just written is exponential,
and starts getting pretty slow as we increase `n`.
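To get a feel for that growth, here is a small sketch that counts how many calls the naive definition makes. `fibCounted` and the enclosing `NaiveFibCost` object are hypothetical names used only for this illustration.

```scala
object NaiveFibCost {
  // Returns (fib(n), number of calls made), following the naive
  // structural definition step by step. Assumes n >= 0.
  def fibCounted(n: Int): (Int, Long) =
    n match {
      case 0 => (0, 1L)
      case 1 => (1, 1L)
      case _ =>
        val (a, callsA) = fibCounted(n - 1)
        val (b, callsB) = fibCounted(n - 2)
        // one call for this invocation, plus those of both branches
        (a + b, 1L + callsA + callsB)
    }
}
```

With these definitions, `fibCounted(10)` makes 177 calls and `fibCounted(20)` makes 21891: ten more steps of `n` multiply the work by more than a hundred.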

So, tail recursion to the rescue!

## Part I: Algebraic Derivation

Let’s start from a skeleton of the naive Fibonacci:

```
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(): Int = ???
    ???
}
```

As you can see, we have left the recursive case unspecified, and
instead introduced a helper function called `go`, which is where all
the actual recursion is going to happen.
This is a very common style since tail-recursive functions tend to
need additional parameters, which are hidden by the helper function.

In particular, each component of the recursive definition becomes a parameter, so with the recurrence relation:

F(n) = F(n - 1) + F(n - 2)

we will have parameters for `n`, `F(n)`, `F(n - 1)`, and `F(n - 2)`.

We can find appropriate names for these parameters by putting them on a number line in order:

```
  F(n - 2)       F(n - 1)          F(n)             n
-------------|---------------|---------------|------------->
 secondLast        last           current        counter
```

which means our function becomes:

```
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(secondLast: Int, last: Int, current: Int, counter: Int): Int = ???
    ???
}
```

In order to call `go`, we need to find appropriate initial values for
its arguments: the initial value for `counter` is 2, because
iterations 0 and 1 are already covered by the base cases.
Then, the other params are given by the Fibonacci series up to iteration 2:

0, 1, 1, …

so:

```
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(secondLast: Int, last: Int, current: Int, counter: Int): Int = ???
    go(secondLast = 0, last = 1, current = 1, counter = 2)
}
```

Now let’s write `go`. Our `fib` function has to return the value of
the Fibonacci sequence at iteration `n`, or in other words, the
`current` value if `counter` is equal to `n`.

If `counter` is not at `n`, then we advance: the old `last` becomes
the new `secondLast`, the old `current` becomes the new `last`, and we
compute the next value of `current` with the definition of Fibonacci,
which says the `current` Fibonacci number is equal to the sum of the
`last` and `secondLast` Fibonacci numbers.

Note that because we have replaced recursive components with arguments
to `go`, we can compute the next value of `current` without any
recursion: we only need to recur in tail position by passing the
updated arguments to `go`:

```
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(secondLast: Int, last: Int, current: Int, counter: Int): Int =
      if (counter == n) current
      else {
        val counterNext = counter + 1
        val secondLastNext = last // we move to the right
        val lastNext = current // we move to the right
        val currentNext = lastNext + secondLastNext // definition of Fibonacci
        go(
          secondLast = secondLastNext,
          last = lastNext,
          current = currentNext,
          counter = counterNext
        )
      }
    go(secondLast = 0, last = 1, current = 1, counter = 2)
}
```

and we’re done! We can apply several transformations to considerably simplify this function, but before we do that, let’s recap the two ideas we’ve used to write tail-recursive functions:

- Every component of the recursive definition becomes an argument to a recursive helper.
- At the end of the recursion, one of the arguments will hold the final result.
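To watch both ideas at work, here is a sketch that runs the same recursion while recording every state it passes through. The `trace` function, its extra `seen` accumulator, and the `FibTrace` object are hypothetical additions for illustration only, and the sketch assumes `n >= 2` since the base cases never reach `go`.

```scala
import scala.annotation.tailrec

object FibTrace {
  type State = (Int, Int, Int, Int) // (secondLast, last, current, counter)

  // Same state updates as `go`, but every state is appended to `seen`
  // so that we can inspect the whole evolution afterwards.
  def trace(n: Int): Vector[State] = {
    require(n >= 2, "iterations 0 and 1 are covered by the base cases")

    @tailrec
    def go(secondLast: Int, last: Int, current: Int, counter: Int,
           seen: Vector[State]): Vector[State] = {
      val seenNext = seen :+ ((secondLast, last, current, counter))
      if (counter == n) seenNext
      else go(last, current, current + last, counter + 1, seenNext)
    }

    go(secondLast = 0, last = 1, current = 1, counter = 2, seen = Vector.empty)
  }
}
```

For `trace(5)` the recorded states are `(0, 1, 1, 2)`, `(1, 1, 2, 3)`, `(1, 2, 3, 4)` and `(2, 3, 5, 5)`: each row is the previous one shifted to the right along the number line, and the `current` component of the last row is the result.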

One final observation is that the recursion is no longer *structural*,
i.e. the fact that the `current` Fibonacci number is defined
recursively no longer corresponds to recursion in code. Instead, the
logical recursion corresponds to state updates, and the recursion in
code is only used to iterate our state transformations until we reach
the desired results. This is the reason why the recursive call can be
in tail position, but I dare say the link between recursive
definitions and state is a far more interesting aspect of this type of
function than the use of tail recursion in itself.

### Simplification

Ok, now onto simplifying our implementation: the first transformation
we can apply is *inlining*, i.e. replacing each name with its
definition. This is safe to do because even though the algorithm is
conceptually evolving variables, it’s expressed as immutable
transformations for which inlining can always be done without changing
behaviour. As an example:

```
go(..., currentNext, ...)
// but:
currentNext = lastNext + secondLastNext
// so:
go(..., lastNext + secondLastNext, ...)
// but:
lastNext = current
secondLastNext = last
// so:
go(..., current + last, ...)
```

Let’s inline all the definitions in `go`:

```
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(secondLast: Int, last: Int, current: Int, counter: Int): Int =
      if (counter == n) current
      else
        go(
          secondLast = last,
          last = current,
          current = current + last,
          counter = counter + 1,
        )
    go(secondLast = 0, last = 1, current = 1, counter = 2)
}
```

but now we notice that `secondLast` is redundant, because it’s updated
but never actually used to compute `current`: `current` is updated via
`current + last`, and `last` doesn’t depend on `secondLast`.

So we can eliminate it to get to:

```
def fib(n: Int) = n match {
  case 0 => 0
  case 1 => 1
  case n =>
    def go(last: Int, current: Int, counter: Int): Int =
      if (counter == n) current
      else
        go(
          last = current,
          current = current + last,
          counter = counter + 1
        )
    go(last = 1, current = 1, counter = 2)
}
```

Now let’s eliminate the named arguments, and hoist `go` to the top.
It’s tail recursive, and we can have the compiler check that by adding
the `tailrec` annotation:

```
import annotation.tailrec

def fib(n: Int) = {
  @tailrec
  def go(last: Int, current: Int, counter: Int): Int =
    if (counter == n) current
    else go(current, current + last, counter + 1)

  n match {
    case 0 => 0
    case 1 => 1
    case n => go(1, 1, 2)
  }
}
```

Finally, we can see that the pattern matching has redundant logic in
the first two cases, and doesn’t use the pattern in the third, so we
can replace it with an `if`. This is our final version:

```
import annotation.tailrec

def fib(n: Int) = {
  @tailrec
  def go(last: Int, current: Int, counter: Int): Int =
    if (counter == n) current
    else go(current, current + last, counter + 1)

  if (n <= 1) n else go(1, 1, 2)
}
```
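Since we got here through a chain of rewrites, it’s worth a quick sanity check against the definition we started from. The sketch below is only that, a sketch: the `FibCheck` object, the `fibNaive` name and the `agreesUpTo` helper are made up for the comparison, and inputs are assumed to be non-negative.

```scala
import scala.annotation.tailrec

object FibCheck {
  // The naive structural definition, used as the reference.
  def fibNaive(n: Int): Int =
    n match {
      case 0 => 0
      case 1 => 1
      case _ => fibNaive(n - 1) + fibNaive(n - 2)
    }

  // The simplified tail-recursive version derived above.
  def fib(n: Int): Int = {
    @tailrec
    def go(last: Int, current: Int, counter: Int): Int =
      if (counter == n) current
      else go(current, current + last, counter + 1)

    if (n <= 1) n else go(1, 1, 2)
  }

  // True if both definitions agree on every input from 0 to `max`.
  def agreesUpTo(max: Int): Boolean =
    (0 to max).forall(n => fib(n) == fibNaive(n))
}
```

`FibCheck.agreesUpTo(20)` should hold; the bound is kept small because the naive reference is the slow part of the check.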

## Part II: Operational Derivation

The derivation above used a lot of equational thinking, but often with tail recursion we can adopt a more operational mindset.

In fact, tail recursion can be understood as a translation from
mutable state algorithms, one where imperative thinking is more
explicit than in most imperative languages, in that for each variable
`x` you distinguish `x` from its updated version `x'` (due to
immutability), and you loop via a restricted `GOTO` (the recursive
call) instead of using `while`.
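To make that correspondence tangible, here is a sketch of what the mutable-state side of the translation could look like for Fibonacci. The `fibLoop` function and the `FibImperative` object are hypothetical names; the variables mirror the parameters of the `go` helper from Part I.

```scala
object FibImperative {
  // Imperative formulation: each parameter of `go` becomes a `var`,
  // and the recursive call becomes another turn of the `while` loop.
  def fibLoop(n: Int): Int =
    if (n <= 1) n
    else {
      var last = 1
      var current = 1
      var counter = 2
      while (counter != n) {
        // `currentNext` plays the role of x': it must be computed from
        // the old values before `last` is overwritten.
        val currentNext = current + last
        last = current
        current = currentNext
        counter += 1
      }
      current
    }
}
```

Note how the loop needs a temporary to avoid reading a variable that has already been overwritten, while the tail-recursive form gets this for free because all the new values are computed from the old ones before the call.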

So, let’s look at the recurrence relation again:

F(0) = 0

F(1) = 1

F(n) = F(n - 1) + F(n - 2)

This time, we start by recognising that the first two cases can be
done with an `if`, and then we know there is going to be some
iteration in the recursive case, which we can represent with a helper.
At this point, the only information we have about the helper is its
return type:

```
def fib(n: Int) = {
  def go(): Int = go()
  if (n <= 1) n else go()
}
```

The guiding principle when deriving tail-recursive functions via operational thinking is figuring out which state you need to keep track of, and add each piece of state as a parameter. Tail recursive helpers have to return a result which is updated during the iteration, so we can start by keeping track of that.

```
def fib(n: Int) = {
  def go(result: Int): Int = go(result)
  if (n <= 1) n else go(???)
}
```

To call `go`, we need to figure out the initial value of `result`. We
know that the `if` returns the result of `fib(n)` when `n == 0` and
`n == 1`, so the initial value of `result` is `fib(n)` when `n == 2`,
which is 1, as per:

0, 1, 1, …

```
def fib(n: Int) = {
  def go(result: Int): Int = go(result)
  if (n <= 1) n else go(1)
}
```

Next, we have to figure out when it is possible to return `result`,
i.e. a *termination condition*. Sometimes this can be done as a
predicate on `result` without any additional state, but in this case
the result has to be returned once we reach the `n-th` iteration,
which means we have to add a `counter: Int` parameter to keep track of
which iteration we’re in.

Because the `if` covers iterations 0 and 1, the initial value of
`counter` will be 2:

```
def fib(n: Int) = {
  def go(result: Int, counter: Int): Int =
    if (counter == n) result
    else go(result, counter)
  if (n <= 1) n else go(result = 1, counter = 2)
}
```

Now that we have some state with initial values, we have to figure out how to update it before recurring:

```
def fib(n: Int) = {
  def go(result: Int, counter: Int): Int =
    if (counter == n) result
    else {
      val counterNext = counter + 1
      val resultNext = ???
      go(resultNext, counterNext)
    }
  if (n <= 1) n else go(result = 1, counter = 2)
}
```

Updating `counter` is trivial, but we don’t know what `resultNext`
should be. Well, according to the definition of Fibonacci, it’s the
sum of `fib(n - 1)` and `fib(n - 2)`, and we already have
`fib(n - 1)`: it’s `result`!

To see why, consider that `resultNext` is the `result`, i.e. the
`fib(n)`, of the *next* iteration, which means that `fib(n - 1)` is
the previous value from the point of view of the next iteration, i.e.
the `result` of the current iteration. Let’s update the code:

```
def fib(n: Int) = {
  def go(result: Int, counter: Int): Int =
    if (counter == n) result
    else {
      val counterNext = counter + 1
      val resultNext = result + ???
      go(resultNext, counterNext)
    }
  if (n <= 1) n else go(result = 1, counter = 2)
}
```

However, we simply cannot compute `fib(n - 2)`, which means we are
missing a parameter to track it: let’s call it `last: Int`.
Understanding why this is the `last` value can be a bit tricky, but
it’s the same idea as before: we need `fib(n - 2)` in the *next*
iteration, which means we need `fib(n - 1)` in this iteration, i.e.
the last value.

Now, `last` is going to need an initial value, and it’s also going to
need to be updated on each iteration before the recursive call.
Let’s start with the initial value, which is straightforward: we need
`fib(n - 1)` when `counter == 2`, which is 1.
The updated `lastNext` value represents, once again, the previous
result from the point of view of the *next* iteration, i.e. the
`result` of the current iteration.

So we have:

```
def fib(n: Int) = {
  def go(result: Int, counter: Int, last: Int): Int =
    if (counter == n) result
    else {
      val counterNext = counter + 1
      val resultNext = result + last
      val lastNext = result
      go(resultNext, counterNext, lastNext)
    }
  if (n <= 1) n else go(result = 1, counter = 2, last = 1)
}
```

And let’s apply the same refactoring as in Part I:

```
import annotation.tailrec

def fib(n: Int) = {
  @tailrec
  def go(result: Int, counter: Int, last: Int): Int =
    if (counter == n) result
    else go(result + last, counter + 1, result)

  if (n <= 1) n else go(1, 2, 1)
}
```
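It’s worth noticing that this is the same function we reached in Part I, just with the parameters in a different order: `result` plays the role of `current`. The sketch below puts the two side by side; the names `fibAlgebraic`, `fibOperational` and `FibBothParts` are made up for the comparison.

```scala
import scala.annotation.tailrec

object FibBothParts {
  // Final version from Part I: go(last, current, counter).
  def fibAlgebraic(n: Int): Int = {
    @tailrec
    def go(last: Int, current: Int, counter: Int): Int =
      if (counter == n) current
      else go(current, current + last, counter + 1)

    if (n <= 1) n else go(1, 1, 2)
  }

  // Final version from Part II: go(result, counter, last).
  def fibOperational(n: Int): Int = {
    @tailrec
    def go(result: Int, counter: Int, last: Int): Int =
      if (counter == n) result
      else go(result + last, counter + 1, result)

    if (n <= 1) n else go(1, 2, 1)
  }
}
```

Checking something like `(0 to 30).forall(n => fibAlgebraic(n) == fibOperational(n))` is a cheap way to convince ourselves that the two derivations really do land in the same place.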

## Conclusion

Although most didactic material covers the idea of tail recursion,
actually *writing* tail-recursive functions is often left as an
exercise to the reader.

This is a shame because, far from just being an optimisation, tail recursion is actually great at expressing tricky stateful logic, and really teaches us to think about state methodically.

So, tail recurse more, and see you next time!