# Lecture 6: More Scala: Map, Filter, and Fold
Originally by Sriram Sankaranarayanan 

Modified by Ravi Mangal 

Last Modified: Feb 10, 2025.

## Higher-order functions over collections in Scala

The functional programming style eschews loops and replaces them with tail-recursive functions that express the logic simply, without carrying extra state. Another mechanism that we will study now is the use of built-in higher-order functions such as `map`, `filter` and `fold` that apply to various collections. Typically, these functions are used to manipulate Lists of objects, but they also apply to other Scala collections such as Maps, Vectors, and Sets.


Before we begin with these higher-order functions, let us study different ways to write functions in Scala, including a convenient notation for anonymous functions.

Let us start with a function that multiplies every element of a list by two.

In [8]:
def multiplyEachEltByTwo(lst: List[Int], accList: List[Int] = Nil): List[Int] = lst match {
  case Nil => accList
  case hd :: tail => {
    val newAccList = accList ++ List(2 * hd) // Why do we add 2 * hd at the end?
    multiplyEachEltByTwo(tail, newAccList)
  }
}

defined function multiplyEachEltByTwo

Next, we would like to remove all the even elements from a list, returning a new list with just the odd elements.

In [9]:
def removeEvenNumbers(lst: List[Int], accList: List[Int] = Nil): List[Int] = lst match {
  case Nil => accList
  case hd :: tail => {
    val newAccList = if (hd % 2 == 0) { accList } else { accList ++ List(hd) }
    removeEvenNumbers(tail, newAccList)
  }
}

defined function removeEvenNumbers

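Both `multiplyEachEltByTwo` and `removeEvenNumbers` take an input list and an accumulator list and return a new list, so (ignoring the default value of the accumulator) each has the type `(List[Int], List[Int]) => List[Int]`.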

Next, we wish to sum up the elements of the list.

In [10]:
def sumOfList(lst: List[Int], sum: Int = 0): Int = lst match {
  case Nil => sum
  case hd :: tail => sumOfList(tail, sum + hd)
}

defined function sumOfList

Finally, we can write our main function:

In [11]:
def processList(lst: List[Int]): Int = {
  val lst1 = removeEvenNumbers(lst)     // Keep only the odd numbers
  val lst2 = multiplyEachEltByTwo(lst1) // Multiply each of them by two
  sumOfList(lst2)                       // Add them up
}

defined function processList

In [12]:
removeEvenNumbers(List(1,3,4,5,6,7,8))
multiplyEachEltByTwo(List(1,2,3,4))
sumOfList ((1 to 100).toList)
processList((1 to 20).toList)

[36mres12_0[39m: [32mList[39m[[32mInt[39m] = [33mList[39m([32m1[39m, [32m3[39m, [32m5[39m, [32m7[39m)
[36mres12_1[39m: [32mList[39m[[32mInt[39m] = [33mList[39m([32m2[39m, [32m4[39m, [32m6[39m, [32m8[39m)
[36mres12_2[39m: [32mInt[39m = [32m5050[39m
[36mres12_3[39m: [32mInt[39m = [32m200[39m

This is somewhat painful, since we need to write three separate functions to do the job. Can we do better? Yes! Let us recognize three patterns of operations that we would like to perform:

- Map: apply a function f to every element of a list.
- Filter: keep just those elements of the list that satisfy a "predicate".
- Fold (or reduce): combine the elements of the list one at a time into an accumulated result, such as a sum.
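As a preview, here is a sketch of `processList` rewritten with the built-in `filter`, `map` and `foldLeft` methods on `List`, one method per pattern. The name `processListHOF` is ours, chosen for illustration; each of these methods is explained in the sections that follow.

```scala
// Hypothetical name: processListHOF (a sketch, not part of the original lecture code).
// Same job as processList: keep the odd numbers, double them, add them up.
def processListHOF(lst: List[Int]): Int =
  lst.filter(x => x % 2 != 0)         // Filter: keep only the odd numbers
    .map(x => 2 * x)                  // Map: multiply each remaining element by two
    .foldLeft(0)((acc, x) => acc + x) // Fold: add everything up, starting from 0

println(processListHOF((1 to 20).toList)) // 200, the same as processList
```

No helper functions and no explicit accumulators are needed: the recursion is hidden inside the three library methods.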

### Anonymous Functions In Scala 

Before we look closer at these operations, let us first familiarize ourselves with anonymous functions in Scala.
It is often cumbersome to define and name a separate function every time we want to pass a function as an argument. Instead, we can use "anonymous" (unnamed) functions.

In [13]:
def multiplyByTwo(x: Int): Int = x * 2

defined [32mfunction[39m [36mmultiplyByTwo[39m

Here are two other ways to write the same thing.

In [3]:
val f : Int => Int = x => x * 2

[36mf[39m: [32mInt[39m => [32mInt[39m = ammonite.$sess.cmd3$Helper$$Lambda$2385/0x0000000100bef840@6968216e

In [5]:
def f_prime(m: Int => Int, y: Int): Int = m(y) // f_prime takes a function m as an argument and applies it to y

defined function f_prime

In [6]:
f_prime(f,2)

[36mres6[39m: [32mInt[39m = [32m4[39m

`f` is bound to a function that takes an argument `x` and returns `x * 2`. As `f_prime(f, 2)` shows, such a function can be passed around like any other value. You can also pass the expression `(x) => x * 2` in any context you wish without giving it a name, as we will see. Here is another succinct version:

In [15]:
val f2: Int => Int = _ * 2

[36mf2[39m: [32mInt[39m => [32mInt[39m = ammonite.$sess.cmd15$Helper$$Lambda$2406/0x0000000100824840@4774c76

The `_` here is a placeholder for the argument: `_ * 2` is shorthand for `x => x * 2`. Often it is important to specify the type of an argument in an anonymous function.

In [16]:
val g = (x: String) => x + x // OK: x is annotated as a String, so Scala infers the type of x + x, and hence the type of g.

g: String => String = ammonite.$sess.cmd16$Helper$$Lambda$2408/0x0000000100821840@20a2a4db

In [17]:
val g: String => String = x => x + x // OK: Scala infers the type of x from the type given to g

g: String => String = ammonite.$sess.cmd17$Helper$$Lambda$2410/0x000000010081c840@70ea23f8

In [7]:
val g2 = x => x + x // BAD: Scala has no way of knowing what x is. It can be a String, Int, Double, ...

cmd8.sc:1: missing parameter type
val g2 = x => x + x // BAD: Scala has no way of knowing what x is. It can be a String, Int, Double, ...
 ^
Compilation Failed
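Unlike `g2`, an anonymous function passed directly as an argument usually needs no type annotation: the declared parameter type of the receiving function tells Scala what `x` is. A small sketch using `f_prime` from above (redefined here so the snippet is self-contained):

```scala
// f_prime as defined earlier: apply the function m to y.
def f_prime(m: Int => Int, y: Int): Int = m(y)

// No name and no annotation needed: m's type Int => Int tells Scala that x is an Int.
val a = f_prime(x => x * 2, 2)  // 4
val b = f_prime(_ * 3, 5)       // 15 (placeholder syntax)
val c = f_prime(x => x + 1, 10) // 11
```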

Anonymous functions can take multiple arguments.

In [2]:
val addFun = (x: Int, y: Int) => x + y

addFun: (Int, Int) => Int = ammonite.$sess.cmd2$Helper$$Lambda$2383/0x0000000100bed040@3659f16a

In [19]:
val addFun = (x: (Int, Int)) => x._1 + x._2 // Different: a single argument that is a pair (tuple) of Ints

addFun: ((Int, Int)) => Int = ammonite.$sess.cmd19$Helper$$Lambda$2422/0x000000010080b840@44e77580

In [20]:
val addFun: (Int, Int) => Int = _ + _ // First _ is the first argument and second _ is the second argument.

[36maddFun[39m: ([32mInt[39m, [32mInt[39m) => [32mInt[39m = ammonite.$sess.cmd20$Helper$$Lambda$2424/0x0000000100809840@214ef5f1

Last but not least, when an anonymous function immediately pattern-matches on its argument, you can omit the `match` and write just the `case` clauses.

In [21]:
sealed trait MyList 
case object MyNil extends MyList
case class MyCons(x: Int, l: MyList) extends MyList

defined [32mtrait[39m [36mMyList[39m
defined [32mobject[39m [36mMyNil[39m
defined [32mclass[39m [36mMyCons[39m

In [22]:
val anonIsEmptyFun: MyList => Boolean = (x) => {
  x match {
    case MyNil => true
    case MyCons(_, _) => false
  }
}

anonIsEmptyFun: MyList => Boolean = ammonite.$sess.cmd22$Helper$$Lambda$2504/0x0000000100c7b840@2b6cc14f

In [23]:
val anonIsEmptyFun: MyList => Boolean = {
  case MyNil => true
  case MyCons(_, _) => false
}

anonIsEmptyFun: MyList => Boolean = ammonite.$sess.cmd23$Helper$$Lambda$2506/0x0000000100c7f840@6fde4e2e

In other words, when you have the pattern
~~~
(x : Type) => x match {
  case .. =>
  case .. =>
  ...
}
~~~
you can instead simply write

~~~
{
  case .. =>
  case .. =>
  ...
}
~~~

without writing `(x : Type) => x match`.
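This shorthand is especially handy when a function is passed straight to a higher-order method. A sketch that reuses the lecture's `MyList` definitions (repeated so the snippet runs on its own); `headOrZero` is a hypothetical name chosen for illustration:

```scala
sealed trait MyList
case object MyNil extends MyList
case class MyCons(x: Int, l: MyList) extends MyList

// Case-form anonymous function: the head of the list, or 0 for the empty list.
val headOrZero: MyList => Int = {
  case MyNil => 0
  case MyCons(x, _) => x
}

val lists: List[MyList] = List(MyNil, MyCons(5, MyNil), MyCons(7, MyCons(8, MyNil)))
val heads = lists.map(headOrZero) // List(0, 5, 7)

// The case form can also destructure a pair right in the argument of map.
val products = List((2, 3), (4, 5)).map { case (a, b) => a * b } // List(6, 20)
```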

## Map, Filter and Fold (Reduce) Operations

In many languages, iteration with for/while loops is replaced by higher-order operations on data structures such as `map`, `filter` and `fold`. In this lecture, we provide a brief overview with some examples, and show how many varieties of loops (or, equivalently, recursion) can be systematically replaced by these operations.


## Map operation

The idea of a map operation is to apply a function $f$ to every member of a container (e.g., a list, array, or map) and return a new container holding the results.
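As a sketch of the "any container" part of this description, the same `map` method exists on other standard Scala collections and on `Option`, and each call returns a new container of the same kind. Note that a `Set` may shrink, because equal results collapse into one element.

```scala
val vec = Vector(1, 2, 3).map(x => x * 10)    // Vector(10, 20, 30)
val s = Set(1, 2, 3, 4).map(x => x % 2)       // Set(1, 0): duplicate results collapse
val o = Some(4).map(x => x + 1)               // Some(5)
val n = (None: Option[Int]).map(x => x + 1)   // None: there is nothing to apply f to
val m = Map("a" -> 1, "b" -> 2).map { case (k, v) => (k, v * 2) } // Map(a -> 2, b -> 4)
```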

### Example 1

We have a list `List(1, 3, 4, 5, 6, 110, 12, 2)`. We wish to compute the square of each element in the list and make a new list with the result.

In [24]:
def recursivelySquareEachElt(l: List[Int], acc: List[Int] = Nil): List[Int] = l match {
  case Nil => acc.reverse
  case hd :: tl => recursivelySquareEachElt(tl, (hd * hd) :: acc)
}

defined function recursivelySquareEachElt

In [25]:
recursivelySquareEachElt(List(10))

res25: List[Int] = List(100)

In [26]:
recursivelySquareEachElt(List(1, 3, 4, 5, 6, 110, 12, 2), Nil)

res26: List[Int] = List(1, 9, 16, 25, 36, 12100, 144, 4)

Now the same computation using the `map` method on lists:

In [27]:
def squareEachElt(l: List[Int]): List[Int] = l.map((x: Int) => x * x)
// (x: Int) => x * x is an anonymous function that squares its argument.

defined function squareEachElt

In [28]:
squareEachElt(List(1, 3, 4, 5, 6, 110, 12, 2))

res28: List[Int] = List(1, 9, 16, 25, 36, 12100, 144, 4)

`l.map(f)` applies the function `f` to each element of the list `l`.

If the elements of `l` have some type `A`, then `f` must have type `A => B` for some type `B`, and `l.map(f)` returns a new list of type `List[B]`. Following is a recursive definition of this function. Note that this function is an example of a [*polymorphic*](https://docs.scala-lang.org/tour/polymorphic-methods.html) function in Scala. (Can you make it tail recursive?)

In [29]:
def listMap[A, B](lst: List[A], fun: A => B): List[B] = lst match {
  case Nil => Nil
  case hd :: tail => fun(hd) :: listMap(tail, fun) // :: is the Cons operator in Scala.
}

defined function listMap

In [30]:
def sayHelloTo(l: List[String]): List[String] = listMap(l, x => ("Hello "+ x)) // Type of x is inferred by Scala

defined [32mfunction[39m [36msayHelloTo[39m

In [31]:
sayHelloTo(List("Cat", "Dog", "World"))

res31: List[String] = List("Hello Cat", "Hello Dog", "Hello World")

## Filter Operation

Just as `map` applies a function to each element to build a new container, `filter` builds a new container from which all elements that do not satisfy a predicate have been removed.

__Predicate:__ A predicate is a function that takes in a value and returns true/false.

In [32]:
def retainAllMultiplesOfThree(l: List[Int], acc: List[Int] = Nil): List[Int] = l match {
  case Nil => acc
  case hd :: tail => {
    val newAcc = if (hd % 3 == 0) { acc ++ List(hd) } else { acc }
    retainAllMultiplesOfThree(tail, newAcc)
  }
}

defined function retainAllMultiplesOfThree

`l.filter(c)` returns a new list containing only those elements of `l` that satisfy the condition `c`; all other elements are dropped.

In [33]:
def retainAllMultiplesOfThree(l: List[Int]): List[Int] = {
  l.filter(x => x % 3 == 0)
}

defined function retainAllMultiplesOfThree

In [34]:
retainAllMultiplesOfThree(List(10, 15, 18, 12, 3, 1, 5, 7, 8, 14))

res34: List[Int] = List(15, 18, 12, 3)
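Since a predicate is just a function that returns a `Boolean`, it can be named and reused. A short sketch (the predicate name `isMultipleOfThree` is ours); `filterNot`, another standard `List` method, keeps exactly the elements that `filter` drops:

```scala
val isMultipleOfThree: Int => Boolean = x => x % 3 == 0
val nums = List(10, 15, 18, 12, 3, 1, 5, 7, 8, 14)

val kept = nums.filter(isMultipleOfThree)       // List(15, 18, 12, 3)
val dropped = nums.filterNot(isMultipleOfThree) // List(10, 1, 5, 7, 8, 14)
val short = nums.filter(_ % 3 == 0)             // placeholder syntax; same result as kept
```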

Here is how the filter operation can be defined for lists, as a polymorphic function:

In [35]:
def filterList[A](lst: List[A], filterFun: A => Boolean): List[A] = lst match {
  case Nil => Nil
  case head :: tail => {
    if (filterFun(head)) {
      head :: filterList(tail, filterFun)
    } else {
      filterList(tail, filterFun)
    }
  }
}

// This is not tail recursive. Why? Can you make it tail recursive?

defined function filterList

## Fold Operations

Fold/reduce operations are useful for combining all the data seen so far during a computation into a single accumulated value. Take a list

$$[l_1, l_2, \ldots, l_n] $$

We wish to sum up the numbers in the list.
This is achieved by a loop with an accumulator:
~~~
acc = 0
for each item in List
 acc = acc + item
return acc
~~~

We can also do this with the `foldLeft` operator.

As an example, consider again the recursive function that sums the elements of a list:

In [36]:
def recSumOfList(lst: List[Int], sum: Int = 0): Int = lst match {
  case Nil => sum
  case hd :: tail => recSumOfList(tail, sum + hd)
}

defined function recSumOfList

Fold is a tricky operation to wrap one's head around. A list data structure gives us two versions of fold.

### list.foldLeft (startVal) (fun)

For a list `[l1, l2, l3, ..., ln]`, the call computes the following nested expression:

`fun(... fun(fun(fun(startVal, l1), l2), l3), ..., ln)`

This is equivalent to the following Scala code:

~~~
var acc = startVal
for (lj <- list)
  acc = fun(acc, lj) // Very important: acc is the first argument and lj is the second argument.
// foldLeft returns the final value of acc.
~~~
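A small sketch to see the nesting concretely: fold with a function that builds a string instead of a number, so that the order of the calls becomes visible.

```scala
// The accumulator is a String; each step wraps it as "(acc + x)".
val leftTrace = List(1, 2, 3).foldLeft("0")((acc, x) => s"($acc + $x)")
println(leftTrace) // (((0 + 1) + 2) + 3)
```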



### list.foldRight (startVal) (fun)

This iterates over the list from right to left. For a list `[l1, l2, l3, ..., ln]`, the call computes the following nested expression:

`fun(l1, fun(l2, ..., fun(ln-2, fun(ln-1, fun(ln, startVal))) ...))`

This is equivalent to the following Scala code:

~~~
var acc = startVal
for (lj <- list.reverse) // Note: the list is iterated in reverse
  acc = fun(lj, acc) // Very important: acc is the second argument for foldRight
// foldRight returns the final value of acc.
~~~
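A small sketch that folds from the right with a function building a string, so that the right-to-left nesting is visible. Note that the element is now the first argument and the accumulator the second.

```scala
// Each step wraps the accumulator as "(x + acc)".
val rightTrace = List(1, 2, 3).foldRight("0")((x, acc) => s"($x + $acc)")
println(rightTrace) // (1 + (2 + (3 + 0)))
```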

The fold function has two arguments: `startVal` and `fun`. Why don't we write `list.foldLeft(startVal, fun)`?
Because `foldLeft` is declared using a special syntax for functions with multiple argument lists
in Scala, called __curried syntax__.

https://alvinalexander.com/scala/fp-book/partially-applied-functions-currying-in-scala

We will talk about currying in detail later on (in a few weeks) and it has nothing to do with Indian cuisine.
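A minimal sketch of curried syntax until then (the function `addCurried` is hypothetical): the method takes its arguments in two separate parameter lists, so it is called with two pairs of parentheses. One practical benefit for `foldLeft` is that, in Scala 2, types are inferred one parameter list at a time, so the type of `startVal` is fixed before the function in the second list is type-checked; that is why the function's parameter types can usually be left out.

```scala
// Hypothetical example: a curried addition function with two parameter lists.
def addCurried(x: Int)(y: Int): Int = x + y

val seven = addCurried(3)(4) // 7

// foldLeft is curried the same way: start value first, then the function.
// The start value 0 fixes the accumulator type (Int), so acc and x need no annotations.
val total = List(1, 2, 3).foldLeft(0)((acc, x) => acc + x) // 6
```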
 

In [37]:
def sumList(l: List[Int]): Int = l.foldLeft (0) ((acc, x) => acc + x )
// Fold left with initial value of accumulator = 0
// Every time we have a new list element x and accumulator value acc, update acc by acc + x

defined [32mfunction[39m [36msumList[39m

In [38]:
sumList(List(1, 2, 3,4, 5, 6, 7, 8, 9, 10))

res38: Int = 55
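The list of patterns at the start of this lecture calls this operation "fold (or reduce)". Scala lists also have a `reduce` method. A sketch of the difference: `reduce` uses the first element as the starting value, so no `startVal` is passed, but it therefore fails on an empty list, while `foldLeft` simply returns its start value.

```scala
val nums = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
val viaFold = nums.foldLeft(0)((acc, x) => acc + x) // 55
val viaReduce = nums.reduce((a, b) => a + b)        // 55, no start value needed

// On an empty list, foldLeft returns the start value, but reduce has nothing to start from.
val emptyFold = List.empty[Int].foldLeft(0)(_ + _)                        // 0
val emptyReduce = scala.util.Try(List.empty[Int].reduce(_ + _)).isFailure // true: reduce throws
```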

In [39]:
def sumFromRight(l: List[Int]) : Int = l.foldRight (0) ((x, acc) => x + acc)

defined [32mfunction[39m [36msumFromRight[39m

Let us now write a function `reverseList` that reverses a list using `foldLeft`:

In [39]:
def reverseList(l: List[Int]): List[Int] = 
l.foldLeft (Nil) ( (listSoFar: List[Int], elt: Int) => {
 elt::listSoFar
} )

cmd40.sc:3: type mismatch;
 found : List[Int]
 required: collection.immutable.Nil.type
 elt::listSoFar
 ^
Compilation Failed

What just happened? Scala's type checker bailed on us.

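- `Nil` is the empty list for any element type: `List[String]`, `List[Int]`, `List[Double]`, `List[List[List[Int]]]`, and so on.
- Scala infers the type of the accumulator in `foldLeft` from the start value alone. Here it picks `Nil.type` (the type of the `Nil` object itself) instead of `List[Int]`, so the function, which returns a `List[Int]`, does not type-check. This is exactly what the error message says: it found `List[Int]` where it required `Nil.type`.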

There are two fixes.


In [40]:
def reverseListA(l: List[Int]): List[Int] =
  l.foldLeft(List[Int]())((listSoFar: List[Int], elt: Int) => {
    elt :: listSoFar
  })

defined function reverseListA

In [41]:
def reverseListB(l: List[Int]): List[Int] =
  l.foldLeft[List[Int]](Nil)((listSoFar: List[Int], elt: Int) => {
    elt :: listSoFar
  })

defined function reverseListB

In general, it is good practice to specify the type of the accumulator in `foldLeft`. Last but not least, note that the anonymous function passed to a fold can be written in case pattern form. Note the syntax for the second argument list of `foldLeft`: instead of enclosing it in `( )`, we use `{ }`. When a function literal is the last argument to a method, you can replace the usual parentheses with curly braces.

In [42]:
def reverseListC(l: List[Int]): List[Int] = l.foldLeft[List[Int]](Nil) {
  case (listSoFar: List[Int], elt: Int) => elt :: listSoFar
}

defined function reverseListC

In [43]:
reverseListA(List(1,2,3,4))

res43: List[Int] = List(4, 3, 2, 1)

In [44]:
reverseListB(List(1,2,3,4))

res44: List[Int] = List(4, 3, 2, 1)

In [45]:
reverseListC(List(1,2,3,4))

res45: List[Int] = List(4, 3, 2, 1)
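To connect folds back to the claim that they can systematically replace many kinds of recursion, here is a sketch of map and filter written with `foldRight` (the names `mapViaFold` and `filterViaFold` are ours). Folding from the right and prepending with `::` keeps the original order without a final `reverse`. These are illustrations only, not necessarily how the standard library implements `map` and `filter`.

```scala
// Map: transform each element and prepend it to the already-mapped rest of the list.
def mapViaFold[A, B](lst: List[A], fun: A => B): List[B] =
  lst.foldRight(List.empty[B])((elt, acc) => fun(elt) :: acc)

// Filter: prepend the element only if it satisfies the predicate.
def filterViaFold[A](lst: List[A], pred: A => Boolean): List[A] =
  lst.foldRight(List.empty[A])((elt, acc) => if (pred(elt)) elt :: acc else acc)

val doubled = mapViaFold(List(1, 2, 3), (x: Int) => 2 * x)         // List(2, 4, 6)
val evens = filterViaFold(List(1, 2, 3, 4), (x: Int) => x % 2 == 0) // List(2, 4)
```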