This post is going to try to simplify the concepts of monads, this means that a lot of technicalities may be omitted, but I’ll be sure to attach links to more technical concepts in this post.

Now, what on earth is a monad?. A more formal definition according to this is

`A monad is a way to structure computations in terms of values and sequences of computations using those values. Monads allow the programmer to build up computations using sequential building blocks, which can themselves be sequences of computations. The monad determines how combined computations form a new computation and frees the programmer from having to code the combination manually each time it is required.`

A slightly informal definition is that a Monad is a **Wrapper** that wraps objects in such a way that you can then easily compose or combine computations of that wrapper type.

according to Mr John Degoes,

Monads are a “functional design pattern” that come up a lot when you’re writing purely-functional code. There’s a good reason for their ubiquity: monads encapsulate the essence of sequential computation.

In Scala, these monads can be used in a for comprehension to effectively compose functions. Here’s a little controversial saying,

**‘The only thing monads are relevant for, from a Scala language perspective, is the ability of being used in a for-comprehension’**

Now, how do we create a Wrapper, or a Monad. For now, let’s say that any class that implements the following methods ** map** and

**can be regarded as a Monad. Although this isn’t entirely true, as there are some laws that must be obeyed, as well as some other technicalities that are involved, but for the sake of this tutorial, we are satisfied with just a**

*flatMap*`map`

and `flatMap`

implementation.To get a practical understanding of Monads, and effectively use them in a for comprehension, Let us imagine this scenario. We go to the supermarket to buy `<strong>Nuts</strong>`

, and each type is wrapped up in a Paper `Bag`

, Now, let’s say we pick a different type of nuts each in their own separate bags and we want to calculate the price of all the different bags of nuts we have, or we throw one bag of nut into another bag of nut, or at a certain point, we decide that we want to reduce the number of nuts in a bag. One way to go about it is to manually unwrap the nuts, perform the maybe complex operations and then wrap them up again in a paper bag because, well you can’t carry that many nuts in your hands can you?

Monads, help us abstract over the wrapping and unwrapping process and just makes you think only about the values contained in the bag in a for comprehension. Enough talk, Let’s see some code.

To model our Monad, The blueprint of any monad is shown below, where Monad represents the class we want to make a Monad

Monad[A] { def map[B](f : A => B) : Monad[B] def flatMap[B](f : A => Monad[B]) : Monad[B] }

The

`map`

method is basically you exposing an API for the user of this monad to manipulate the content of the Monad, which in our case will be a`Nut`

, Note that all you can do is modify the contents of the Monad, and thereafter the result will be wrapped for you. The`flatMap`

method, on the other hand, which is more powerful exposes an API for you to manipulate the contents of the current monad for you to use to generate a new Monad, it’s a method to produce thenextcomputation from the value of thecurrentcomputation , so in our case, we can take the contents of a bag and not just modify the contents (like taking a part of the nuts we want, taking from some other types of nuts), but we then return any other kind of bag if we want containing any kind of nut.

And type A represents the value we want to wrap, in our case, Nuts is modeled thus.

sealed trait Nuts{ val quantity : Int private valprice: Double = 2.0 def total = quantity *price} case class Peanut(quantity : Int) extends Nuts case class GroundNut(quantity : Int) extends Nuts case class CashewNut(quantity : Int) extends Nuts case class MixedNuts(quantity : Int) extends Nuts

Now, let’s create our Bag that should wrap these nuts and should also be a Monad

case class Bag[A](item: A) { def map[B](f: A => B): Bag[B] = Bag(f(item)) def flatMap[B](f: A => Bag[B]): Bag[B] = f(item) } object Bag { def unit[A](v: A) =Bag(v) }

With the `<strong>map</strong>`

and `<strong>flatMap</strong>`

method, we can easily just operate on the contents of the bag of nuts without having to unwrap the bag, and then wrap it again. If you closely observe the `<strong>map</strong>`

and `<strong>flatMap</strong>`

method, they take care of the wrapping and unwrapping for us.

There’s a companion object with a

`unit`

method, which is not necessarily needed for a Monad in our context, but if you are interested, in the latter part of this post, I’ll use it to prove the laws that our Monad class actually satisfies

Now, let’s go buy 3 types of nuts, and for some weird reason, when we are at checkout, we apply a discount to the first, remove 20 pieces from the second and double the quantity of the third. Modelling this with some utility functions, we have

valcashewBag= Bag(CashewNut(40)) valgroundNutBag= Bag(GroundNut(40)) valpeanutBag= Bag(Peanut(40)) /** you pay for only 3/4 of the number of nuts * if you have more than 20 nuts */ def applyDiscount(quantity : Int) : Int = if(quantity > 20)(0.75 * quantity ).toInt else quantity def printTotal(nut : Nuts) : Unit = println(s"total is ${result.item.total}")

Now, If our Bag type wasn’t a monad, the cleanest, regular non-monadic way to do this would have been this

val newCashewBag = Bag(CashewNut(applyDiscount(cashewBag.item.quantity))) val newGroundBag = Bag(GroundNut(groundNutBag.item.quantity - 20)) val newPeanutBag = Bag(Peanut(peanutBag.item.quantity * 2)) val mixedNuts = Bag(MixedNuts(newCashewBag.item.quantity + newGroundBag.item.quantity + newPeanutBag.item.quantity)) printTotal(mixedNuts.item)

Looking at this code, we can see that it’s not easily composable, there is always a call to get the item which unwraps the bag, performing some computation and then wrapping the result in a new Bag as underlined. We can also see that we are creating unnecessary `<strong>val</strong>`

variables for each wrapping and unwrapping. Now, imagine we had over 100 kinds of nuts, this would easily have been a nightmare.

But since our Bag class is a Monad, and because we have declared methods for manipulating contents of our monad, we see that we can easily and sequentially compose our computations without having to deal with wrapping and unwrapping anything or declaring unnecessary `val`

variables.

val result = for { a <- cashewBag.map(x => CashewNut(applyDiscount(x.quantity))) b <- groundNutBag.map(x => GroundNut(x.quantity + 20)) c <- peanutBag.map(x => Peanut(x.quantity * 2)) } yield { MixedNuts(a.quantity + b.quantity + c.quantity) } printTotal(result.item)

The beauty of this is that it’s so clean and composable, the fact that the Monad has abstracted over wrapping and unwrapping its contents and just exposed the methods `map`

and `flatMap`

to deal with the contents of the wrapper.

**Conclusion**

And that’s it, Monads are really this easy, this example may be a little trivial, but let’s imagine if the scala `<strong>Future</strong>`

class wasn’t a Monad, take a moment to imagine you couldn’t call `map`

and `<strong>flatMap</strong>`

on a `<strong>Future</strong>`

, you would easily see that it would have been a nightmare, having to deal with callback hell. Also, imagine there was no `<strong>map</strong>`

and `<strong>flatMap</strong>`

methods on lists, or Options.

Behind the scenes, in a for comprehension, the compiler actually uses the `map`

and `flatMap`

method, to see that, try compiling your class with `scalac -Xprint:parse`

**Extras**

If you remember that I said for a class to be a monad, there are some laws that must be obeyed. If you are interested in this, you can continue to read further. But here are the laws really simplified using Scala.

**MONAD LAWS**

Before, we explain these laws, remember we defined a method called `unit`

which some other people call `<strong>identity</strong>`

or `<strong>point</strong>`

e.t.c. It serves just a basic purpose as described below.

The`unit`

method is basically just used to wrap an object into a wrapper of that type, in our case, we can supply a`Nut`

to the unit method, and it wraps it in a`Bag`

for us as a`Bag[Nut]`

The three monad laws according to the Haskell Wiki are informally expressed as

- Left Identity
- Right Identity
- Associativity

As we test for these laws, let’s define some functions `f`

and `g`

which are of type `A -> Monad[A]`

, these are just functions that take a type A and return a Monad of type A, we will also define an object of type `Monad[A]`

to carry out these tests on. In our case, `A `

is `Nuts`

.

valcashew=CashewNut(10) // value valcashewBag=Bag(cashew) // test Monad instance valf: CashewNut => Bag[CashewNut] = nut =>Bag(CashewNut(nut.quantity + 30)) valg: CashewNut => Bag[CashewNut] = nut =>Bag(CashewNut(nut.quantity * 50))

I am testing on `CashewNut`

, but the law will apply to all types of nuts. Now let’s see what these laws are

**LEFT IDENTITY **

So, if we have some basic value *x*, a monad instance *m* (holding some value) and functions *f* and *g *of type *Int → M[Int]*, we can write the laws as follows:

The left identity law states that a value`<strong>x</strong>`

, and our defined functions`<strong>f</strong>`

and`<strong>g</strong>`

, the following must holdunit(x).flatMap(f) == f(x)

Testing this law we have

println(Bag.unit(cashew).flatMap(f) == f(cashew)) // prints true

**RIGHT IDENTITY**

The right identity law states if we have a monad instance`<strong>X</strong>`

,calling`<strong>flatMap</strong>`

on that instance and passing a`unit`

should return the same monad instance`<strong>X</strong>`

, mathematically represented asX.flatMap(unit) ==X.

Let’s test this also

println(cashewBag.flatMap(Bag.unit) == cashewBag) // prints true

**ASSOCIATIVITY**

This law states that for a monad instance`<strong>m</strong>`

and with our already defined functions`<strong>f</strong>`

and`g`

, the following must holdm.flatMap(f).flatMap(g) == m.flatMap(x ⇒ f(x).flatMap(g))

Testing this law,

println(cashewBag.flatMap(f).flatMap(g) == cashewBag.flatMap(x ⇒ f(x).flatMap(g))) // prints true

We can clearly see that the Monad we just created is in fact, a mathematically correct Monad.