The correspondence between flatMap and map + flatten is well known, and in practical use flatMap is much more common, but what are some of the use cases for using flatten on its own?
We are going to look at two examples, using Typelevel libraries for our case study, and in the end we’ll hopefully discover something interesting about the nature of flatMap.
Note: this post is very much not a monad tutorial; you are expected to be already familiar with monads in practical use.
Definitions recap
Let’s forget the existence of Applicative for this post, and look at these two definitions of Monad:
trait Monad[F[_]] {
  def pure[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

trait Functor[F[_]] {
  def map[A, B](fa: F[A])(f: A => B): F[B]
}

trait Monad[F[_]] extends Functor[F] {
  def pure[A](a: A): F[A]
  def flatten[A](ffa: F[F[A]]): F[A]
}
which are, of course, equivalent:
// definition 2) in terms of 1)
def map[A, B](fa: F[A])(f: A => B): F[B] =
  flatMap(fa)(a => pure(f(a)))

def flatten[A](ffa: F[F[A]]): F[A] =
  flatMap(ffa)(identity)

// definition 1) in terms of 2)
def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] =
  flatten(map(fa)(f))
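To make the equivalence concrete, here is a quick sanity check (mine, not from the post) using Option as F, where both formulations agree:

// flatMap vs map-then-flatten on Option: both yield Some(42)
val value: Option[Int] = Some(21)
val double: Int => Option[Int] = n => Some(n * 2)

assert(value.flatMap(double) == value.map(double).flatten)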
Now, definition 2) in terms of map and flatten (also called join) is closer to the Category Theory definition of Monad, and is therefore often preferred for theoretical explanations, but the goal of this post is different: we are going to focus on a couple of practical use cases where it makes sense to use flatten on its own, instead of flatMap.
Example 1: Ref.modify
For our first example, we are going to look at cats.effect.Ref, a purely functional, concurrent mutable reference which I’ve already described in a couple of talks. You can check them out if you aren’t familiar with it, but for now we will have a look at this slightly simplified API:
import cats.implicits._
import cats.effect.IO

trait Ref[A] {
  def get: IO[A]
  def set(a: A): IO[Unit]
  def modify[B](f: A => (A, B)): IO[B]
}
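As an aside, the real cats.effect.Ref is parameterised on the effect type, Ref[IO, A], and is created with Ref.of. A minimal usage sketch of the same three operations (assuming Cats Effect 3, with the import renamed to avoid clashing with the simplified trait above):

import cats.effect.{IO, Ref => CERef}

// create a reference holding 0, set it to 10, then read it back
val refDemo: IO[Int] =
  for {
    ref <- CERef.of[IO, Int](0)
    _   <- ref.set(10)
    v   <- ref.get
  } yield v // 10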
get and set should be self-explanatory, so let’s look at an example of modify: imagine we’re modeling a race, and each sprinter needs to update their arrival position and return it.
def buggySprinter(finishLine: Ref[Int]): IO[Int] =
  finishLine.get
    .flatMap(pos => finishLine.set(pos + 1))
    .flatMap(_ => finishLine.get)
This may seem correct at first glance, but note how the first flatMap is not safe in the presence of concurrency: a concurrent process could change the value of the Ref in between get and set (resulting in lost updates), so we need an atomic update(f: A => A): IO[Unit].
This is not enough either, though: a concurrent process could change the value after we update it but before we get it (resulting in an incorrect final result), so we also need the ability to return a value atomically, at which point we have rediscovered modify:
def sprinter(finishLine: Ref[Int]): IO[Int] =
  finishLine.modify { pos =>
    (pos + 1, pos + 1)
  }
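To see the difference under actual concurrency, here is a hedged sketch (mine, not from the post, and assuming Cats Effect 3 so that a Parallel[IO] instance is in scope) that runs three sprinters at once against the real cats.effect.Ref; every position from 1 to 3 is handed out exactly once:

import cats.implicits._
import cats.effect.{IO, Ref => CERef}

val raceDemo: IO[List[Int]] =
  CERef.of[IO, Int](0).flatMap { finishLine =>
    // three concurrent sprinters, each atomically taking the next position
    List.fill(3)(finishLine.modify(pos => (pos + 1, pos + 1))).parSequence
  }
// yields List(1, 2, 3) in some order, with no lost updates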
Modify, flatten and transactionality
Depending on your concurrency background, you might be thinking of modify as involving locks, but the way things actually work is closer to transactions: the update is retried on concurrent conflicts until it succeeds, using something called a CAS loop.
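To make the CAS loop idea concrete, here is a minimal sketch (an assumption about the general shape, not the actual cats-effect source) of how a modify-like operation could be built on top of java.util.concurrent.atomic.AtomicReference:

import java.util.concurrent.atomic.AtomicReference
import cats.effect.IO

def casModify[A, B](ar: AtomicReference[A])(f: A => (A, B)): IO[B] =
  IO {
    @annotation.tailrec
    def loop(): B = {
      val current        = ar.get
      val (next, result) = f(current)
      // compareAndSet succeeds only if nobody changed the value in between;
      // on a concurrent conflict we retry with the freshly observed state
      if (ar.compareAndSet(current, next)) result else loop()
    }
    loop()
  }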
Under this perspective, we can see why the flatMap-based examples were problematic: flatMap on IO has no concept of transactionality, which is given by the special modify method.
To see where flatten comes into play, I’m going to very briefly introduce a powerful pattern involving Ref: concurrent state machines (again, watch this talk if you want to know more).
In a concurrent state machine, there is a set of states, transitions between those states, and actions to be run upon each transition. Moreover, the transitions are triggered concurrently, meaning that the transition function needs to follow the transactional pattern outlined above: the corresponding action should run only once the concurrent transition has been successfully registered.
For example, let’s imagine a simple machine which will notify the user via an HTTP call when the lights are switched on or off.
// this is the set of states
// it's going to be held in a `Ref`
sealed trait LightSwitch
case object On extends LightSwitch
case object Off extends LightSwitch

// imagine something richer
type HttpStatus = Int

// these are our actions
trait Notifications {
  // e.g. via http call
  def notify(msg: String): IO[HttpStatus]
}

// our finite state machine
trait FSM {
  def toggle: IO[HttpStatus]
}
So, how do we implement toggle? Remember that Ref.modify allows us to atomically change the state and return a value, and IOs are values, which means we can return the action itself.

modify(f: A => (A, B)): IO[B]
modify(f: LightSwitch => (LightSwitch, IO[HttpStatus])): IO[IO[HttpStatus]]
It’s important to understand the semantics of that IO[IO[HttpStatus]]: when the outer IO runs, it will try to change the current state of FSM, possibly retrying multiple times in case of conflict, but without running the inner IO[HttpStatus], which is merely returned at the end.
Once the outer IO returns, it means that the state of FSM has been successfully changed, and it is now possible to run the action represented by the inner IO[HttpStatus]. In other words, we now need to go from IO[IO[HttpStatus]] to IO[HttpStatus], which we can do with… flatten.
The full code looks like this:
def fsm(states: Ref[LightSwitch], actions: Notifications): FSM = new FSM {
  def toggle: IO[HttpStatus] = states.modify {
    case Off => (On, actions.notify("The light has been turned on"))
    case On  => (Off, actions.notify("The light has been turned off"))
  }.flatten
}
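For completeness, here is one way to wire everything up; the adapter bridges the real cats.effect.Ref[IO, A] to the simplified Ref[A] trait used in this post (the names makeRef and program are mine, not from the original code):

def makeRef[A](a: A): IO[Ref[A]] =
  cats.effect.Ref.of[IO, A](a).map { underlying =>
    new Ref[A] {
      def get: IO[A] = underlying.get
      def set(a: A): IO[Unit] = underlying.set(a)
      def modify[B](f: A => (A, B)): IO[B] = underlying.modify(f)
    }
  }

def program(notifications: Notifications): IO[HttpStatus] =
  for {
    states <- makeRef[LightSwitch](Off)
    machine = fsm(states, notifications)
    status <- machine.toggle // atomically flips the switch, then notifies
  } yield status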
So we have our first case of standalone flatten, which happens when there is some notion of transactionality: you can think of F[F[A]] as a program that returns another program, where the first program is transactional and the second is not.
We used Ref.modify, but this principle extends to other scenarios as well:
- The Doobie library exposes a ConnectionIO[A] type that represents a database transaction. So you would have a ConnectionIO[IO[A]], transact to get to IO[IO[A]], then flatten (a hedged sketch follows after this list).
- Haskell’s Software Transactional Memory is based on an STM a type that represents a transactional concurrent computation (similar to a multivariable Ref). So you would have an STM (IO a), atomically to get to IO (IO a), then join (Haskell’s version of flatten).
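Here is the hedged Doobie sketch mentioned above, reusing the Notifications trait and HttpStatus alias from this example; the table, the query and the registerUser name are made up for illustration, but the transact-then-flatten shape is the point:

import cats.implicits._
import cats.effect.IO
import doobie._
import doobie.implicits._

// the insert runs inside the transaction; the notification runs only after it commits
def registerUser(xa: Transactor[IO], notifications: Notifications): IO[HttpStatus] = {
  val tx: ConnectionIO[IO[HttpStatus]] =
    sql"insert into users (name) values ('Dotty')".update.run
      .map(_ => notifications.notify("User registered"))

  tx.transact(xa).flatten // IO[IO[HttpStatus]] => IO[HttpStatus]
}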
Example 2: JSON decoding
For our second example we are going to look at JSON decoding with Circe, focusing on a specific issue my team encountered recently. We have some simple JSON, representing our two different types of users:
import io.circe._
import io.circe.parser._
import io.circe.generic.semiauto._

def json(s: String): Json = parse(s).getOrElse(Json.Null)

def named = json {
  """
  {
    "named" : {
      "name": "Dotty",
      "surname" : "McDotFace"
    }
  }
  """
}

def unnamed = json {
  """
  {
    "unnamed" : {
      "id" : 13355
    }
  }
  """
}
which maps to the following ADT:
sealed trait User
case class Unnamed(id: Long) extends User
case class Named(name: String, surname: String) extends User
So let’s go ahead and write a Decoder for it:
def unnamedDec: Decoder[User] =
  deriveDecoder[Unnamed]
    .prepare(_.downField("unnamed"))
    .widen[User]

def namedDec: Decoder[User] =
  deriveDecoder[Named]
    .prepare(_.downField("named"))
    .widen[User]

def userDec = unnamedDec orElse namedDec

userDec.decodeJson(named)
// res2: Decoder.Result[User] = Right(Named("Dotty", "McDotFace"))
userDec.decodeJson(unnamed)
// res3: Decoder.Result[User] = Right(Unnamed(13355L))
A few points about the code above:
- deriveDecoder automatically derives a decoder for the individual cases.
- prepare modifies the input json before feeding it to the decoder. In our case we need to access the “unnamed” and “named” json objects before decoding the corresponding data (see the short illustration after this list).
- widen changes a Decoder[B] into a Decoder[A] when B <: A. It’s the explicit equivalent of covariance.
- decA orElse decB will try decA, and fall back on decB if decA fails.
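As a quick illustration of prepare on its own (a toy example of mine, not part of the post’s code), it simply moves the cursor before the wrapped decoder runs:

val idDec: Decoder[Long] = Decoder[Long].prepare(_.downField("id"))

idDec.decodeJson(json("""{ "id": 42 }"""))
// evaluates to Right(42)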
So far this works great, but look at what happens when we send an incorrect unnamed user:
def incorrectUnnamed: Json = json {
  """
  {
    "unnamed" : {
      "id" : "not a long"
    }
  }
  """
}

userDec.decodeJson(incorrectUnnamed)
// res4: Decoder.Result[User] = Left(
//   DecodingFailure(Attempt to decode value on failed cursor, List(DownField(named)))
// )
As you can see, Circe does give back a nice error message, but unfortunately it’s coming from the wrong branch.
We know that if the json contains unnamed, it’s an unnamed user, but Circe does not: it sees a failure and falls back to the named user decoder with orElse, which obviously fails to decode as well, at which point you get the error from the last branch.
In general this is ok, but in our specific use case it was a source of pain for the users of our API, so we wanted to report errors pertinent to the ADT case they were sending to us (assuming they got at least the “named/unnamed” tag right).
flatMap and orElse
To see where the problem lies, it’s easy to think of the code above as being made of these four components (not 100% true in Circe terms, but the differences are irrelevant):
- an “unnamed” accessor Decoder, of type unnamedField: Decoder[Json]
- a Decoder for unnamed users, of type unnamedData: Json => Decoder[User]
- a “named” accessor Decoder, of type namedField: Decoder[Json]
- a Decoder for named users, of type namedData: Json => Decoder[User]
where the whole decoder is
unnamedField.flatMap(unnamedData) orElse namedField.flatMap(namedData)
If you look, for example, at unnamedField.flatMap(unnamedData), you can see that there are two possible sources of error: one is in unnamedField, which means the tag is not “unnamed”, and one is in unnamedData, which means that the data format is wrong.
Crucially, we want orElse to only operate on the first source, so we have to separate them. One way to do that would be:
unnamedField.orElse(namedField).flatMap { json =>
  unnamedData(json) orElse namedData(json)
}
But that’s not ideal: first of all there is repetition of orElse (in our actual scenario there were many more cases), but it’s also awkward to have to define unnamedField and unnamedData separately.
We somehow want to keep them together, but without triggering the second source of errors until after orElse; in other words, we want to return the second Decoder program without running it (and therefore without triggering its errors).
Again, this corresponds to Decoder[Decoder[User]], which we can get by changing flatMap to map:
unnamedField.map(unnamedData) orElse namedField.map(namedData)
orElse will give us another Decoder[Decoder[User]], and we can now run the inner decoder with… flatten.
And this is our second case of standalone flatten, which happens when we want to interleave another operation (in this case orElse) in between the map and flatten parts of flatMap.
The full code contains some Circe details which aren’t super relevant conceptually, but I’m leaving it here for interested readers. Note the correct, informative error trail at the end.
implicit class Tagger[A](d: Decoder[A]) {
  def tag(accessor: String): Decoder[Decoder[A]] = Decoder.instance { inputJson =>
    inputJson.downField(accessor) match {
      case innerJson: HCursor => Right(innerJson)
      case err: FailedCursor  => Left(DecodingFailure("Failed cursor", err.history))
    }
  }.map(outJson => Decoder.instance(_ => d(outJson)))
}

def betterDec: Decoder[User] =
  deriveDecoder[Unnamed]
    .widen[User]
    .tag("unnamed")
    .orElse {
      deriveDecoder[Named]
        .widen[User]
        .tag("named")
    }.flatten

betterDec.decodeJson(named)
// res5: Decoder.Result[User] = Right(Named("Dotty", "McDotFace"))
betterDec.decodeJson(unnamed)
// res6: Decoder.Result[User] = Right(Unnamed(13355L))
betterDec.decodeJson(incorrectUnnamed)
// res7: Decoder.Result[User] = Left(
//   DecodingFailure(Long, List(DownField(id), DownField(unnamed)))
// )
Conclusion: the nature of flatMap
Let’s look again at the signature of flatMap:
def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
You can look at F[A] as a program that returns As, and at flatMap as a mode of composition that represents running two programs in sequence, using the result of the first to decide the shape of the second (A => F[B] means exactly “F[B] depends on A”).
But as we’ve seen, there are actually two components at play, map and flatten, so we can now refine our intuition: map represents running the first program and using its result to decide the shape of the second, and flatten then actually runs the second program.
Most of the time, these two things happen as a unit (hence why flatMap is more prominent), but not all the time: either because we want to take the decision about which program to run in a different way (in the case of modify, transactionally), or because we need to run other operations first (like orElse).
Slightly more formally, monads are about substitution (map) followed by renormalisation (flatten) and, as we’ve seen, sometimes you need to manipulate the non-normalised form F[F[A]] on its own.
Hope you enjoyed this post, and see you next time!