When I was learning Scala and first came across the phrase *partial function*, something about it seemed odd. How can a function be partial? I started digging through the documentation and searching online, and what came out was a surprisingly clean solution to a problem every programmer has run into.

Most functions or methods apply some validation rule to make sure the parameters passed to them are valid, and that makes the code a bit messy: the logic that validates the parameters is mixed in with the actual business logic, which runs only if the parameters are correct. In other words, the function only works on a subset of its possible inputs. Partial functions provide a much cleaner way to express this.

So let's look at the definition of a *partial function*. Quoting from the Scaladoc:

> A partial function of type PartialFunction[A, B] is a unary function where the domain does not necessarily include all values of type A.

In Scala, partial functions are implemented as a trait, PartialFunction[A, B], which takes two type parameters: A is the input type and B is the type the partial function returns. For example, if a partial function takes a String and returns an Int, then A is String and B is Int, so its type is PartialFunction[String, Int].

Implementing the trait means providing two abstract methods:

- apply
- isDefinedAt

Let's see how to create a partial function:

```scala
val newLength = new PartialFunction[String, Int] {

  // Business logic: runs when the partial function is applied
  def apply(s: String) = {
    println("Hello from apply function!")
    s.length
  }

  // Validation logic: is this parameter in the function's domain?
  def isDefinedAt(s: String) = {
    println("Hello from isDefinedAt function")
    s.length > 3
  }
}
```

We have now written two methods. *apply* holds the actual logic, which should be applied when the function is supplied with a valid value.

**isDefinedAt** checks whether the parameter used when calling the partial function is valid. This gives a clean separation between business logic and data-validation logic. However, be aware that isDefinedAt is not called automatically: calling newLength(...) runs apply directly, so isDefinedAt has to be called manually.

Let's extend this further by calling the partial function:

```scala
val str = "Hello World"

// Check whether the parameter is valid
// Prints - Hello from isDefinedAt function
val validValue = newLength.isDefinedAt(str)

// Prints - true
println(validValue)

// Pass the parameter to the partial function
// Prints - Hello from apply function!
val len = newLength(str)

// Prints - 11
println(len)
```
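Because apply never consults isDefinedAt on its own, the hand-written newLength will compute the length of any string, even one that fails the check. One way to guard the call is the standard library's lift method, which wraps a partial function in a total function returning an Option: Some(result) when the input is in the domain, None otherwise. A minimal sketch, assuming the same newLength definition with the println calls dropped to keep the output short (safeLength is a hypothetical name):

```scala
// Same logic as newLength above, without the println calls
val newLength = new PartialFunction[String, Int] {
  def apply(s: String): Int = s.length
  def isDefinedAt(s: String): Boolean = s.length > 3
}

// apply runs even when isDefinedAt would return false
println(newLength("abc"))          // prints 3

// lift checks isDefinedAt first and wraps the result in an Option
// (safeLength is a hypothetical name used for illustration)
val safeLength: String => Option[Int] = newLength.lift

println(safeLength("Hello World")) // prints Some(11)
println(safeLength("abc"))         // prints None
```

If you would rather supply a default value than get an Option, applyOrElse(x, fallback) is the standard alternative: it calls apply when isDefinedAt(x) is true and the fallback function otherwise.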

Before we dive deeper, let's look at another way of writing partial functions: using *case* statements. A case block combines apply and isDefinedAt and cuts down the amount of code, and for valid input the result is the same. Keep in mind that although this looks like the body of a match expression, it is not quite the same thing: the compiler turns the block into a PartialFunction and generates isDefinedAt from the pattern and its guard. The key behavioral difference from the hand-written version above is that if the parameter does not pass the condition, calling the function throws a scala.MatchError instead of running the logic anyway. See below:

```scala
val newLength: PartialFunction[String, Int] = {
  case s if s.length > 3 => s.length
}

val str = "Hello World"

// Check whether the parameter is valid
val validValue = newLength.isDefinedAt(str)

// Prints - true
println(validValue)

// Pass the parameter to the partial function
val len = newLength(str)

// Prints - 11
println(len)
```
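To see the difference in behavior, the sketch below calls the case-based newLength on an input outside its domain and catches the resulting scala.MatchError. It also uses collect, a standard collection method that checks isDefinedAt for each element and keeps only the results for elements where the partial function is defined. The sample words and the name failed are made up for illustration:

```scala
val newLength: PartialFunction[String, Int] = {
  case s if s.length > 3 => s.length
}

// isDefinedAt is generated from the guard
println(newLength.isDefinedAt("abc"))   // prints false

// Applying the function to an undefined input throws a MatchError
// (failed is a hypothetical name used for illustration)
val failed =
  try { newLength("abc"); false }
  catch { case _: MatchError => true }
println(failed)                         // prints true

// collect keeps only the elements the partial function is defined at
val lengths = List("Hi", "Hello", "Scala").collect(newLength)
println(lengths)                        // prints List(5, 5)
```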

This may not sound like a big deal yet, but another feature of partial functions makes all of this worthwhile. Here goes.

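**orElse** and **andThen** are two methods of the PartialFunction trait that let you chain partial functions, so that several functions can each handle a different part of the same problem while being called as one. orElse combines two partial functions into one that is defined wherever either of them is, trying the first and falling back to the second; andThen passes the result of a partial function on to another function.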

For example, suppose we have a list of numbers and want to multiply the even numbers by 5 and the odd numbers by 10, and then divide every result by 2. We can chain the partial functions that perform each step into a single partial function. See below:

```scala
// Partial function 1: defined only for even numbers
val multiplyBy5: PartialFunction[Int, Int] = {
  case n if n % 2 == 0 => n * 5
}

// Partial function 2: defined only for odd numbers
// (n % 2 != 0 also covers negative odd numbers, where n % 2 == -1)
val multiplyBy10: PartialFunction[Int, Int] = {
  case n if n % 2 != 0 => n * 10
}

// Partial function 3: defined for every Int
val divideBy2: PartialFunction[Int, Int] = {
  case n => n / 2
}

// Create a new partial function by joining the three partial functions:
// multiply the list elements by 5 if they are even, by 10 if they are odd,
// and then finally divide the result by 2
val multiply = multiplyBy5 orElse multiplyBy10 andThen divideBy2

val list = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

// Prints (one per line) - 5, 5, 15, 10, 25, 15, 35, 20, 45, 25
list.map(multiply).foreach(println)
```
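Because orElse and andThen return partial functions, the combined function still knows its own domain. As a minimal sketch (evensHalved is a hypothetical name, and multiplyBy5 and divideBy2 are repeated from the example above so the snippet stands alone), chaining only the even-number case with andThen gives a function that is undefined for odd numbers, so collect can skip them instead of throwing a MatchError as map would:

```scala
val multiplyBy5: PartialFunction[Int, Int] = {
  case n if n % 2 == 0 => n * 5
}

val divideBy2: PartialFunction[Int, Int] = {
  case n => n / 2
}

// Defined only where multiplyBy5 is defined, i.e. for even numbers
// (evensHalved is a hypothetical name used for illustration)
val evensHalved = multiplyBy5.andThen(divideBy2)

println(evensHalved.isDefinedAt(3))  // prints false

// collect skips the odd numbers instead of throwing a MatchError
val result = (1 to 10).toList.collect(evensHalved)
println(result)                      // prints List(5, 10, 15, 20, 25)
```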

Until next time.