During our training and mentoring, we often hear students say “I’m not sure what to do here” when they’re trying to choose which function to call to solve their problem. First, that’s completely normal; nobody instantly knows the answer. To help, we offer a heuristic that we’ve found useful:

The solution is usually either a map, flatMap, or fold.*

In this post we’ll be discussing the power of folds: how they are used to summarize a data structure, how summarization can be decomposed and abstracted, and how that abstraction can be leveraged to provide better performance in a generic way.

*And if it isn’t one of those, it’s probably traverse.

Decomposing Summarization

If a List contains numbers, we can compute the sum of the elements using a fold:

List(1, 2, 3).foldLeft(0)((sum, i) => sum + i)
// res0: Int = 6

foldLeft has a rather generic signature, so let’s factor out the necessary knowledge to summarize our List of Ints:

  • We need a starting value to initialize our summary, in case the list is empty: ifEmpty: Int; and
  • We need a way to combine the previous summary with the next element: combine: (Int, Int) => Int.
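To see both pieces at work, here is a small illustrative sketch using the standard library’s scanLeft, which records every intermediate result of a left fold: the first entry is ifEmpty, and each later entry is combine applied to the previous summary and the next element. The object and value names are our own.

```scala
object FoldSteps {
  val ifEmpty: Int = 0
  val combine: (Int, Int) => Int = (sum, i) => sum + i

  // scanLeft keeps every intermediate summary of the left fold:
  // ifEmpty, then 0 + 1, then 1 + 2, then 3 + 3
  val steps: List[Int] = List(1, 2, 3).scanLeft(ifEmpty)(combine)

  // folding an empty list never calls combine; it just returns ifEmpty
  val emptyTotal: Int = List.empty[Int].foldLeft(ifEmpty)(combine)
}
```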

Let’s write a helper function summarize to name these parts, and just delegate to foldLeft as the implementation:

def summarize(is: List[Int], ifEmpty: Int, combine: (Int, Int) => Int): Int =
  is.foldLeft(ifEmpty)((sum, i) => combine(sum, i))

summarize(List(1, 2, 3), 0, _ + _)
// res1: Int = 6
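Because summarize takes ifEmpty and combine as arguments, the same helper can compute other summaries of a List[Int]. A sketch, with product and maximum as our own illustrative choices; note that each summary needs a starting value that leaves the result unchanged (1 for multiplication, Int.MinValue for max).

```scala
object SummarizeExamples {
  // same helper as above
  def summarize(is: List[Int], ifEmpty: Int, combine: (Int, Int) => Int): Int =
    is.foldLeft(ifEmpty)((sum, i) => combine(sum, i))

  val total: Int   = summarize(List(1, 2, 3, 4), 0, _ + _)
  val product: Int = summarize(List(1, 2, 3, 4), 1, _ * _)

  // Int.MinValue is the starting point for max, so an empty list summarizes to it
  val largest: Int = summarize(List(3, 9, 4), Int.MinValue, (a, b) => math.max(a, b))
}
```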

But how could we summarize any list, not just a list of numbers? Let’s allow the caller to choose the element type, and adjust the ifEmpty and combine helpers:

def summarize2[A](as: List[A], ifEmpty: A, combine: (A, A) => A): A =
  as.foldLeft(ifEmpty)((sum, a) => combine(sum, a))

summarize2[Int](List(1, 2, 3), 0, _ + _)
// res2: Int = 6

Now the caller can summarize a list any way they want. But it’s annoying to have to supply ifEmpty and combine for every call; those parameters will (almost) always be the same for every type A.
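To make that repetition concrete, here is a sketch (with illustrative values) of summarize2 at two more element types. Each call site has to restate the same ifEmpty and combine for its type: "" and concatenation for String, true and && for Boolean.

```scala
object Summarize2Examples {
  def summarize2[A](as: List[A], ifEmpty: A, combine: (A, A) => A): A =
    as.foldLeft(ifEmpty)((sum, a) => combine(sum, a))

  // String: start from "" and concatenate
  val joined: String = summarize2[String](List("a", "b", "c"), "", _ + _)

  // Boolean: start from true and require every element to be true
  val allTrue: Boolean = summarize2[Boolean](List(true, false, true), true, _ && _)
}
```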

Capturing Common Behavior in a Typeclass

Let’s factor those two helpers into a typeclass:

// n.b. essentially the same shape as cats.Monoid (which also extends Semigroup)
trait Monoid[A] {
  def empty: A
  def combine(a1: A, a2: A): A
}
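The trait alone doesn’t stop us from writing a nonsensical instance. A monoid is expected to obey two laws: empty is an identity for combine (combining with it changes nothing), and combine is associative, so the grouping of combinations doesn’t matter; associativity is what makes it safe to regroup a fold. Here is a minimal sketch of a hypothetical lawsHold helper (our own, not a cats API) that spot-checks both laws on a few sample values.

```scala
object MonoidLawsSketch {
  trait Monoid[A] {
    def empty: A
    def combine(a1: A, a2: A): A
  }

  // hypothetical helper: checks identity and associativity on the given samples
  def lawsHold[A](samples: List[A])(implicit M: Monoid[A]): Boolean = {
    val identityHolds = samples.forall { a =>
      M.combine(M.empty, a) == a && M.combine(a, M.empty) == a
    }
    val associativityHolds = samples.forall { x =>
      samples.forall { y =>
        samples.forall { z =>
          M.combine(M.combine(x, y), z) == M.combine(x, M.combine(y, z))
        }
      }
    }
    identityHolds && associativityHolds
  }

  val intAddition: Monoid[Int] = new Monoid[Int] {
    def empty: Int = 0
    def combine(i1: Int, i2: Int): Int = i1 + i2
  }

  // subtraction breaks both laws: 0 - 5 != 5, and (x - y) - z != x - (y - z) in general
  val intSubtraction: Monoid[Int] = new Monoid[Int] {
    def empty: Int = 0
    def combine(i1: Int, i2: Int): Int = i1 - i2
  }

  val additionOk: Boolean    = lawsHold(List(-2, 0, 5))(intAddition)
  val subtractionOk: Boolean = lawsHold(List(-2, 0, 5))(intSubtraction)
}
```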

Then we require the Monoid, which contains our empty and combine helpers, as an implicit parameter:

def summarize3[A](as: List[A])(implicit M: Monoid[A]): A =
  as.foldLeft(M.empty)((sum, a) => M.combine(sum, a))

Let’s create a Monoid[Int] instance so we can use it in summarize3:

implicit val intMonoid: Monoid[Int] =
  new Monoid[Int] {
    def empty: Int = 0
    def combine(i1: Int, i2: Int): Int = i1 + i2
  }

Finally we’re back where we started:

summarize3(List(1, 2, 3))
// res3: Int = 6

summarize3 will now work for lists of any type A, as long as A has a Monoid[A].
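For instance, here is a sketch with two more instances of our own choosing, one for String and one for Boolean (using true and &&). With these in implicit scope, summarize3 needs nothing but the list. The names stringMonoid and allMonoid are ours, not from a library.

```scala
object MoreInstances {
  trait Monoid[A] {
    def empty: A
    def combine(a1: A, a2: A): A
  }

  def summarize3[A](as: List[A])(implicit M: Monoid[A]): A =
    as.foldLeft(M.empty)((sum, a) => M.combine(sum, a))

  implicit val stringMonoid: Monoid[String] = new Monoid[String] {
    def empty: String = ""
    def combine(s1: String, s2: String): String = s1 + s2
  }

  // "all" semantics: true is the identity for &&
  implicit val allMonoid: Monoid[Boolean] = new Monoid[Boolean] {
    def empty: Boolean = true
    def combine(b1: Boolean, b2: Boolean): Boolean = b1 && b2
  }

  val joined: String   = summarize3(List("fold", "Map"))
  val allTrue: Boolean = summarize3(List(true, true, false))
}
```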

Deriving foldMap

But what if we have a list of some type that doesn’t, or can’t, have a Monoid? To summarize it, we’ll need to transform each element into a type that does have a Monoid. Is there a way to do this?

There is, and it’s a common strategy in functional programming: we ask the caller for help, since our summarizing function can’t know how to do the transformation itself. If the caller provides a function A => B, and B has a Monoid, then we can summarize. Aha!

This function is usually called foldMap, because we’re both fold-ing (the list) and map-ping the elements (the A => B function):

// Summarize a List[A] as a B, if B has a Monoid.
def foldMap[A, B](as: List[A])(f: A => B)(implicit M: Monoid[B]): B =
  as.foldLeft(M.empty)((b, a) => M.combine(b, f(a)))
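As a sketch of the mapping step, suppose we have a hypothetical Order type with no Monoid of its own. We can still total the quantities by mapping each Order to an Int, which does have a Monoid. The Order type and sample data are illustrative only.

```scala
object FoldMapExample {
  trait Monoid[A] {
    def empty: A
    def combine(a1: A, a2: A): A
  }

  implicit val intMonoid: Monoid[Int] = new Monoid[Int] {
    def empty: Int = 0
    def combine(i1: Int, i2: Int): Int = i1 + i2
  }

  def foldMap[A, B](as: List[A])(f: A => B)(implicit M: Monoid[B]): B =
    as.foldLeft(M.empty)((b, a) => M.combine(b, f(a)))

  // hypothetical element type with no Monoid of its own
  final case class Order(item: String, quantity: Int)

  val orders = List(Order("apple", 2), Order("pear", 5), Order("plum", 1))

  // map each Order to its quantity, then combine the quantities with Monoid[Int]
  val totalQuantity: Int = foldMap(orders)(_.quantity)
}
```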

If A already has a Monoid, then we can summarize using foldMap by not transforming the data. That is, we transform it with the identity function:

def sum[A: Monoid](as: List[A]): A =
  foldMap(as)(identity)

sum(List(1, 2, 3, 4, 5))
// res4: Int = 15

It’s not very exciting, but we can compute the length of a list too (implicitly using the Monoid[Int] to “increment”):

def count[A](as: List[A]): Int =
  foldMap(as)(_ => 1) // increment the count by 1 for each element

count(List(1, 2, 3, 4, 5))
// res5: Int = 5

Computing the Mean with foldMap

What can you do with the sum and the count of a list? Compute the mean!

val l = List(1, 2, 3, 4, 5)
sum(l)
// res6: Int = 15
count(l)
// res7: Int = 5

def mean(is: List[Int]): Double =
  sum(is).toDouble / count(is)

mean(l)
// res8: Double = 3.0

It works!
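One edge case worth knowing about: for an empty list both the sum and the count are 0, so the division evaluates 0.0 / 0, which is Double.NaN rather than an exception. A minimal sketch, using the standard library’s sum and size as stand-ins for our sum and count; meanOption is a hypothetical alternative, not part of the original post.

```scala
object MeanEdgeCase {
  // the standard library's sum and size stand in for our sum and count
  def mean(is: List[Int]): Double =
    is.sum.toDouble / is.size

  // one hypothetical alternative: make the empty case explicit
  def meanOption(is: List[Int]): Option[Double] =
    if (is.isEmpty) None else Some(mean(is))

  val fromList: Double          = mean(List(1, 2, 3, 4, 5))
  val fromEmpty: Double         = mean(Nil) // 0.0 / 0 is NaN
  val safeEmpty: Option[Double] = meanOption(Nil)
}
```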

Computing the Mean with foldMap: Improved!

For the performance-minded, there’s an issue with the computation of the mean: it processes the list twice, once for the sum and once for the count. Could we compute the mean in only one pass?

One way to approach this is to fill in, as best we can, the arguments to a single call to foldMap that computes the answer we’re looking for. That is, we need foldMap to return the sum AND the count:

def onePassMean(is: List[Int]): Double = {
  // need sum AND count from foldMap
  val (sum, count) =
    foldMap(is)(i => ???)

  sum.toDouble / count
}

If foldMap is to return an (Int, Int) tuple, then the A => B mapping function (the ??? above) needs to return (Int, Int), and foldMap also requires a Monoid[(Int, Int)]:

def onePassMean(is: List[Int]): Double = {
  val (sum: Int, count: Int) =
    foldMap(is)(i => (???): (Int, Int)) // needs to produce (Int, Int)

  sum.toDouble / count
}
// error: could not find implicit value for parameter M: repl.Session.App.Monoid[(Int, Int)]
//     foldMap(is)(i => (???): (Int, Int))
//     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We know what the components of the tuple should be: the first is the sum and the second is the count. When foldMap uses the Monoid[(Int, Int)] to combine tuples, it needs to work like this:

val sumAndCount = (10, 4) // (previous sum, previous count)
// sumAndCount: (Int, Int) = (10, 4)
val toAdd = (5, 1) // (increment the sum by 5, increment the count by 1)
// toAdd: (Int, Int) = (5, 1)
val totals = (10 + 5, 4 + 1)
// totals: (Int, Int) = (15, 5)

This generalizes to any pair of types: combining two tuples produces one tuple whose components are each combined. Sound familiar? As long as there is a Monoid for each component of the tuple, here’s the Monoid we need:

implicit def tuple2Monoid[A, B](implicit MA: Monoid[A], MB: Monoid[B]): Monoid[(A, B)] =
  new Monoid[(A, B)] {
    // the empty tuple is a tuple of empty values from each Monoid
    def empty: (A, B) = (MA.empty, MB.empty)

    // combine A fields via MA, combine B fields via MB
    def combine(ab1: (A, B), ab2: (A, B)): (A, B) =
      (ab1, ab2) match {
        case ((a1, b1), (a2, b2)) =>
          (MA.combine(a1, a2), MB.combine(b1, b2))
      }
  }

tuple2Monoid[Int, Int].combine((10, 4), (5, 1))
// res11: (Int, Int) = (15, 5)
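A standalone sketch (it repeats the definitions so it compiles on its own) showing two consequences of how the instance is built: its empty for (Int, Int) is (0, 0), and folding a list of pairs with it adds the first and second components independently. The names pairMonoid and totals are ours.

```scala
object TupleMonoidSketch {
  trait Monoid[A] {
    def empty: A
    def combine(a1: A, a2: A): A
  }

  implicit val intMonoid: Monoid[Int] = new Monoid[Int] {
    def empty: Int = 0
    def combine(i1: Int, i2: Int): Int = i1 + i2
  }

  // same idea as tuple2Monoid above, using ._1/._2 instead of a pattern match
  implicit def tuple2Monoid[A, B](implicit MA: Monoid[A], MB: Monoid[B]): Monoid[(A, B)] =
    new Monoid[(A, B)] {
      def empty: (A, B) = (MA.empty, MB.empty)
      def combine(ab1: (A, B), ab2: (A, B)): (A, B) =
        (MA.combine(ab1._1, ab2._1), MB.combine(ab1._2, ab2._2))
    }

  val pairMonoid: Monoid[(Int, Int)] = implicitly[Monoid[(Int, Int)]]

  // the empty pair is built from the empty Int on each side
  val emptyPair: (Int, Int) = pairMonoid.empty

  // folding several pairs adds the first and second components independently
  val totals: (Int, Int) =
    List((10, 4), (5, 1), (3, 2)).foldLeft(pairMonoid.empty)(pairMonoid.combine)
}
```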

Now we can compute our mean with one pass over the data:

def onePassMean(is: List[Int]): Double = {
  val (sum, count) =
    foldMap(is)(i => (i, 1)) // (increment the sum by i, increment the count by 1)

  sum.toDouble / count
}

onePassMean(l)
// res12: Double = 3.0
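The same trick extends to more statistics at once. As a sketch (an extension of our own, not from the original post), nesting tuples lets a single foldMap pass collect the sum, the count, and the sum of squares, from which the population variance follows as E[x^2] - E[x]^2. That textbook formula can lose precision on large values, so treat it as an illustration of the technique rather than a robust statistics routine.

```scala
object OnePassStats {
  trait Monoid[A] {
    def empty: A
    def combine(a1: A, a2: A): A
  }

  implicit val intMonoid: Monoid[Int] = new Monoid[Int] {
    def empty: Int = 0
    def combine(i1: Int, i2: Int): Int = i1 + i2
  }

  implicit def tuple2Monoid[A, B](implicit MA: Monoid[A], MB: Monoid[B]): Monoid[(A, B)] =
    new Monoid[(A, B)] {
      def empty: (A, B) = (MA.empty, MB.empty)
      def combine(ab1: (A, B), ab2: (A, B)): (A, B) =
        (MA.combine(ab1._1, ab2._1), MB.combine(ab1._2, ab2._2))
    }

  def foldMap[A, B](as: List[A])(f: A => B)(implicit M: Monoid[B]): B =
    as.foldLeft(M.empty)((b, a) => M.combine(b, f(a)))

  // one pass: (sum, (count, sum of squares)); the compiler nests tuple2Monoid
  def onePassVariance(is: List[Int]): Double = {
    val (sum, (count, sumSq)) = foldMap(is)(i => (i, (1, i * i)))
    val mean = sum.toDouble / count
    sumSq.toDouble / count - mean * mean
  }

  val variance: Double = onePassVariance(List(1, 2, 3, 4, 5))
}
```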

Summary

  • Summarizing a List is a fold.
  • The initial value and the operation for combining elements into a summary can be abstracted as a Monoid.
  • foldMap summarizes a List by requiring that each element can be transformed into a type that has a Monoid. foldMap is available in cats as part of the Foldable typeclass, which models foldable (“summarizable”) types like List.
  • We can derive a Monoid instance for a tuple whose components each have a Monoid, which gives us a way to combine values “in parallel”. Typeclass instances for tuples are included in cats, and are usually brought into scope via the wildcard import cats.implicits._.

// The same example as above, but using cats.
import cats.implicits._

List(1, 2, 3, 4, 5).foldMap(i => (i, 1))
// res13: (Int, Int) = (15, 5)

Further reading: