The correspondence between flatMap and map + flatten is well known, and in practical use flatMap is much more common, but what are some of the use cases for using flatten on its own? We are going to look at two examples, using Typelevel libraries for our case study, and in the end we’ll hopefully discover something interesting about the nature of flatMap.
Note: this post is very much not a monad tutorial; you are expected to be already familiar with monads in practical use.
Definitions recap
Let’s forget the existence of Applicative for this post, and look at these two definitions of Monad:
trait Monad[F[_]] {
  def pure[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

trait Functor[F[_]] {
  def map[A, B](fa: F[A])(f: A => B): F[B]
}

trait Monad[F[_]] extends Functor[F] {
  def pure[A](a: A): F[A]
  def flatten[A](ffa: F[F[A]]): F[A]
}
which are, of course, equivalent:
// definition 2) in terms of 1)
def map[A, B](fa: F[A])(f: A => B): F[B] =
  flatMap(fa)(a => pure(f(a)))

def flatten[A](ffa: F[F[A]]): F[A] =
  flatMap(ffa)(identity)

// definition 1) in terms of 2)
def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] =
  flatten(map(fa)(f))
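To make the correspondence concrete, here is a quick sanity check using List, whose standard-library flatMap, map and flatten already obey it (this little example is mine, not part of the original definitions):

val xs = List(1, 2, 3)
val f: Int => List[Int] = n => List(n, n * 10)

// flatMap is map followed by flatten
assert(xs.flatMap(f) == xs.map(f).flatten) // both are List(1, 10, 2, 20, 3, 30)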
Now, definition 2), in terms of map and flatten (also called join), is closer to the Category Theory definition of Monad, and therefore often preferred for theoretical explanations, but the goal of this post is different: we are going to focus on a couple of practical use cases where it makes sense to use flatten on its own, instead of flatMap.
Example 1: Ref.modify
For our first example, we are going to look at cats.effect.Ref, a purely functional, concurrent mutable reference which I’ve already described in a couple of talks. You can check them out if you aren’t familiar with it, but for now we will have a look at this slightly simplified API:
import cats.implicits._
import cats.effect.IO

trait Ref[A] {
  def get: IO[A]
  def set(a: A): IO[Unit]
  def modify[B](f: A => (A, B)): IO[B]
}
get and set should be self-explanatory, so let’s look at an example of modify: imagine we’re modeling a race, and each sprinter needs to update their arrival position and return it.
def buggySprinter(finishLine: Ref[Int]): IO[Int] =
  finishLine.get
    .flatMap(pos => finishLine.set(pos + 1))
    .flatMap(_ => finishLine.get)
This may seem correct at first glance, but note how the first flatMap is not safe in the presence of concurrency: a concurrent process could change the value of the Ref in between get and set (resulting in lost updates), so we need an atomic update(f: A => A): IO[Unit].
This is not enough either, though: a concurrent process could change the value after we update it but before we get it (resulting in an incorrect final result), so we also need the ability to return a value atomically, at which point we have rediscovered modify:
def sprinter(finishLine: Ref[Int]): IO[Int] =
  finishLine.modify { pos =>
    (pos + 1, pos + 1)
  }
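As an aside, once you have modify, simpler atomic operations can be derived from it. Here is a minimal sketch (the update and updateAndGet helpers below are mine, written against the simplified trait above; they mirror methods the real cats.effect.Ref also offers):

// update discards the return value, updateAndGet returns the new state
def update[A](ref: Ref[A])(f: A => A): IO[Unit] =
  ref.modify(a => (f(a), ()))

def updateAndGet[A](ref: Ref[A])(f: A => A): IO[A] =
  ref.modify { a => val next = f(a); (next, next) }

With these, sprinter is simply updateAndGet(finishLine)(_ + 1).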
Modify, flatten and transactionality
Depending on your concurrency background, you might be thinking of modify as involving locks, but the way things actually work is closer to transactions: the update is retried on concurrent conflicts until it succeeds, using something called a CAS loop. Under this perspective, we can see why the flatMap-based examples were problematic: flatMap on IO has no concept of transactionality, which is given by the special modify method.
To see where flatten comes into play, I’m going to very briefly introduce a powerful pattern involving Ref: concurrent state machines (again, watch this if you want to know more).
In a concurrent state machine, there is a set of states, transitions between these states, and actions to be run upon each transition. Moreover, the transitions are triggered concurrently, meaning that the transition function needs to follow the transactional pattern outlined above: the corresponding action should only run once the concurrent transition has been successfully registered.
For example, let’s imagine a simple machine which will notify the user via an HTTP call when the lights are switched on or off.
// this is the set of states
// it's going to be held in a `Ref`
sealed trait LightSwitch
case object On extends LightSwitch
case object Off extends LightSwitch

// imagine something richer
type HttpStatus = Int

// these are our actions
trait Notifications {
  // e.g. via http call
  def notify(msg: String): IO[HttpStatus]
}

// our finite state machine
trait FSM {
  def toggle: IO[HttpStatus]
}
So, how do we implement toggle? Remember that Ref.modify allows us to atomically change the state and return a value, and IOs are values, which means we can return the action itself.
modify(f: A => (A, B)): IO[B]
modify(f: LightSwitch => (LightSwitch, IO[HttpStatus])): IO[IO[HttpStatus]]
It’s important to understand the semantics of that IO[IO[HttpStatus]]: when the outer IO runs, it will try to change the current state of FSM, possibly retrying multiple times in case of conflict, but without running the inner IO[HttpStatus], which is merely returned at the end.
Once the outer IO returns, it means that the state of FSM has been successfully changed, and it is now possible to run the action represented by the inner IO[HttpStatus]. In other words, we now need to go from IO[IO[HttpStatus]] to IO[HttpStatus], which we can do with… flatten.
The full code looks like this:
def fsm(states: Ref[LightSwitch], actions: Notifications): FSM = new FSM {
  def toggle: IO[HttpStatus] = states.modify {
    case Off => (On, actions.notify("The light has been turned on"))
    case On  => (Off, actions.notify("The light has been turned off"))
  }.flatten
}
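If you want to actually construct one of these, here is a hedged sketch of how it could be wired up with the real cats.effect.Ref (the mkFsm helper and the anonymous adapter are mine, bridging the simplified Ref trait used in this post):

import cats.effect.{Ref => CatsRef}

def mkFsm(actions: Notifications): IO[FSM] =
  CatsRef.of[IO, LightSwitch](Off).map { underlying =>
    // adapt the real Ref[IO, LightSwitch] to the simplified trait above
    val states = new Ref[LightSwitch] {
      def get: IO[LightSwitch] = underlying.get
      def set(a: LightSwitch): IO[Unit] = underlying.set(a)
      def modify[B](f: LightSwitch => (LightSwitch, B)): IO[B] = underlying.modify(f)
    }
    fsm(states, actions)
  }

Calling mkFsm(notifications).flatMap(_.toggle) would then register the transition and fire the notification.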
So we have our first case of standalone flatten, which happens when there is some notion of transactionality: you can think of F[F[A]] as a program that returns another program, where the first program is transactional and the second is not.
We used Ref.modify, but this principle extends to other scenarios as well:
- The Doobie library exposes a ConnectionIO[A] type that represents a database transaction. So you would have a ConnectionIO[IO[A]], transact to get to IO[IO[A]], then flatten (see the sketch after this list).
- Haskell’s Software Transactional Memory is based on an STM a type that represents a transactional concurrent computation (similar to a multivariable Ref). So you would have an STM (IO a), atomically to get to IO (IO a), then join (Haskell’s version of flatten).
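Here is a hedged sketch of what the Doobie case might look like (the query, sendConfirmation and the transactor xa are made up for illustration; only the ConnectionIO[IO[A]] / transact / flatten shape is the point):

import doobie._
import doobie.implicits._

def sendConfirmation(email: String): IO[Unit] = IO.unit // stand-in for a real action

def register(xa: Transactor[IO], name: String, email: String): IO[Unit] =
  sql"insert into users (name, email) values ($name, $email)".update.run
    .map(_ => sendConfirmation(email)) // ConnectionIO[IO[Unit]]: return the action, don't run it
    .transact(xa)                      // IO[IO[Unit]]: the transaction commits here
    .flatten                           // IO[Unit]: only then run the follow-up action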
Example 2: JSON decoding
For our second example we are going to look at JSON decoding with Circe, focusing on a specific issue my team encountered recently. We have some simple JSON, representing our two different types of users:
import io.circe._
import io.circe.parser._
import io.circe.generic.semiauto._

def json(s: String): Json = parse(s).getOrElse(Json.Null)

def named = json {
  """
  {
    "named" : {
      "name": "Dotty",
      "surname" : "McDotFace"
    }
  }
  """
}

def unnamed = json {
  """
  {
    "unnamed" : {
      "id" : 13355
    }
  }
  """
}
which maps to the following ADT:
sealed trait User
case class Unnamed(id: Long) extends User
case class Named(name: String, surname: String) extends User
So let’s go ahead and write a Decoder for it:
def unnamedDec: Decoder[User] =
  deriveDecoder[Unnamed]
    .prepare(_.downField("unnamed"))
    .widen[User]

def namedDec: Decoder[User] =
  deriveDecoder[Named]
    .prepare(_.downField("named"))
    .widen[User]

def userDec = unnamedDec orElse namedDec

userDec.decodeJson(named)
// res2: Decoder.Result[User] = Right(Named("Dotty", "McDotFace"))
userDec.decodeJson(unnamed)
// res3: Decoder.Result[User] = Right(Unnamed(13355L))
A few points about the code above:
- deriveDecoder automatically derives a decoder for the individual cases.
- prepare modifies the input json before feeding it to the decoder. In our case we need to access the “unnamed” and “named” json objects before decoding the corresponding data.
- widen changes a Decoder[B] into a Decoder[A] when B <: A. It’s the explicit equivalent of covariance.
- decA orElse decB will try decA, and fall back on decB if decA fails.
So far this works great, but look at what happens when we send an incorrect unnamed user:
def incorrectUnnamed: Json = json {
  """
  {
    "unnamed" : {
      "id" : "not a long"
    }
  }
  """
}
userDec.decodeJson(incorrectUnnamed)
// res4: Decoder.Result[User] = Left(
// DecodingFailure(Attempt to decode value on failed cursor, List(DownField(named)))
// )
As you can see, Circe does give back a nice error message, but unfortunately it’s coming from the wrong branch. We know that if the json contains unnamed, it’s an unnamed user, but Circe does not: it sees a failure and falls back to the named user decoder with orElse, which obviously fails to decode as well, at which point you get the error from the last branch.
In general, this is ok, but in our specific use case it was a source of pain for the users of our API, so we wanted to report errors pertinent to the ADT case they were sending to us (assuming they got at least the “named/unnamed” tag right).
flatMap and orElse
To see where the problem lies, it’s easy to think of the code above as being made of these four components (not 100% true in Circe terms, but the differences are irrelevant):
- an “unnamed” accessor Decoder, of type unnamedField: Decoder[Json]
- a Decoder for unnamed users, of type unnamedData: Json => Decoder[User]
- a “named” accessor Decoder, of type namedField: Decoder[Json]
- a Decoder for named users, of type namedData: Json => Decoder[User]
where the whole decoder is
unnamedField.flatMap(unnamedData) orElse namedField.flatMap(namedData)
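For concreteness, here is one hedged way those four components could be written with plain Circe combinators (these implementations are mine and are only an illustration; the post’s real code uses prepare and the derived decoders shown earlier):

val unnamedField: Decoder[Json] = Decoder.instance(_.downField("unnamed").as[Json])
val namedField: Decoder[Json]   = Decoder.instance(_.downField("named").as[Json])

val rawUnnamed: Decoder[User] = deriveDecoder[Unnamed].map(u => u: User)
val rawNamed: Decoder[User]   = deriveDecoder[Named].map(u => u: User)

// each data decoder ignores its own cursor and decodes the json it was handed
val unnamedData: Json => Decoder[User] =
  json => Decoder.instance(_ => rawUnnamed.decodeJson(json))
val namedData: Json => Decoder[User] =
  json => Decoder.instance(_ => rawNamed.decodeJson(json))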
If you look, for example, at unnamedField.flatMap(unnamedData), you can see that there are two possible sources of error: one is in unnamedField, which means the tag is not “unnamed”, and one is in unnamedData, which means that the data format is wrong.
Crucially, we want orElse to only operate on the first source, so we have to separate them. One way to do that would be:
unnamedField.orElse(namedField).flatMap { json =>
  unnamedData(json) orElse namedData(json)
}
But that’s not ideal: first of all there is repetition of orElse (in our actual scenario there were many more cases), but it’s also awkward to have to define unnamedField and unnamedData separately.
We somehow want to keep them together, but without triggering the second source of errors until after orElse, or in other words, we want to return the second Decoder program without running it (and therefore without triggering its errors).
Again, this corresponds to Decoder[Decoder[User]], which we can get by changing flatMap to map:
unnamedField.map(unnamedData) orElse namedField.map(namedData)
orElse will give us another Decoder[Decoder[User]], and we can now run the inner decoder with… flatten.
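Spelled out with the hypothetical components sketched above, the whole thing would read like this (again an illustration, not the post’s final code, which follows below):

val sketchDec: Decoder[User] =
  (unnamedField.map(unnamedData) orElse namedField.map(namedData)).flatten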
And this is our second case of standalone flatten, which happens when we want to interleave another operation (in this case orElse) in between the map and flatten parts of flatMap.
The full code contains some Circe details which aren’t super relevant conceptually, but I’m leaving it here for the interested readers. Note the correct, informative error trail at the end.
implicit class Tagger[A](d: Decoder[A]) {
  // Decoder[Decoder[A]]: the outer decoder only checks that `accessor` is present,
  // the inner one decodes its contents with `d`
  def tag(accessor: String): Decoder[Decoder[A]] = Decoder.instance { inputJson =>
    inputJson.downField(accessor) match {
      case innerJson: HCursor => Right(innerJson)
      case err: FailedCursor  => Left(DecodingFailure("Failed cursor", err.history))
    }
  }.map(outJson => Decoder.instance(_ => d(outJson)))
}

def betterDec: Decoder[User] =
  deriveDecoder[Unnamed]
    .widen[User]
    .tag("unnamed")
    .orElse {
      deriveDecoder[Named]
        .widen[User]
        .tag("named")
    }
    .flatten
betterDec.decodeJson(named)
// res5: Decoder.Result[User] = Right(Named("Dotty", "McDotFace"))
betterDec.decodeJson(unnamed)
// res6: Decoder.Result[User] = Right(Unnamed(13355L))
betterDec.decodeJson(incorrectUnnamed)
// res7: Decoder.Result[User] = Left(
// DecodingFailure(Long, List(DownField(id), DownField(unnamed)))
// )
Conclusion: the nature of flatMap
Let’s look again at the signature of flatMap:
def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
You can look at F[A] as a program that returns As, and at flatMap as a mode of composition that represents running two programs in sequence, using the result of the first to decide the shape of the second (A => F[B] means exactly “F[B] depends on A”).
But as we’ve seen, there are actually two components at play, map and flatten, so we can now refine our intuition: map represents running the first program and using its result to decide the shape of the second, and flatten then actually runs the second program.
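A tiny IO illustration of that split (the coin flip and the names are mine):

val chooseProgram: IO[IO[Unit]] =
  IO(scala.util.Random.nextBoolean()).map { heads =>
    // the result of the first program decides the shape of the second,
    // but the second program is only returned here, not run
    if (heads) IO(println("heads")) else IO(println("tails"))
  }

// flatten sequences the chosen program after the coin flip, so it actually runs
val runChosen: IO[Unit] = chooseProgram.flatten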
Most of the time, these two things happen as a unit (hence why flatMap is more prominent), but not all the time: either because we want to take the decision about which program to run in a different way (in the case of modify, transactionally), or because we need to run other operations first (like orElse).
Slightly more formally, monads are about substitution (map) followed by renormalisation (flatten), and, as we’ve seen, sometimes you need to manipulate the non-normalised form F[F[A]] on its own.
Hope you enjoyed this post, and see you next time!