Sunday, July 31, 2011

Partially unexpected effects of chaining partial functions in Scala

People who learn Scala usually agree that pattern matching is a great feature that helps make code more expressive. Some time later they discover that the case blocks they write in match expressions can also be used on their own, as partial functions. And since partial functions are full-blown functions, you can combine them in a couple of useful ways:



f1 orElse f2

Combines two partial functions into a new one; if f1 is not defined at its argument, tries f2.

f1 andThen f2

Applies f2 to the result of f1. This means that the output type of f1 must be compatible with the input type of f2.
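A quick illustration of both combinators. The partial functions small and large and the function upper are made-up names for this sketch; the andThen argument is a plain String => String function, which keeps the example independent of later additions to the standard library.

```scala
// Two partial functions with disjoint domains (names are illustrative only)
val small: PartialFunction[Int, String] = { case i if i < 10 => s"small $i" }
val large: PartialFunction[Int, String] = { case i if i >= 100 => s"large $i" }

// orElse: use `small` where it is defined, otherwise fall back to `large`
val sized: PartialFunction[Int, String] = small orElse large

// andThen: feed every result of `sized` into a total String => String function
val upper: String => String = _.toUpperCase
val shouted: PartialFunction[Int, String] = sized andThen upper

println(sized(3))               // small 3
println(sized(250))             // large 250
println(sized.isDefinedAt(50))  // false: neither function covers 50
println(shouted(3))             // SMALL 3
```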



orElse


As others have discovered, you can combine a list (any TraversableOnce, really) of partial functions into one with reduce. What's not so obvious, though, is that the way you combine them can have unexpected performance consequences.
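The two folds build differently shaped functions. A small sketch that reduces strings instead of partial functions makes the nesting visible (the names f1, f2, f3 are placeholders):

```scala
// Stand-ins for f1, f2, f3: combine their names instead of the functions themselves
val names = Seq("f1", "f2", "f3")

// reduceLeft nests to the left, reduceRight nests to the right
val leftNested  = names.reduceLeft((acc, f) => s"($acc orElse $f)")
val rightNested = names.reduceRight((f, acc) => s"($f orElse $acc)")

println(leftNested)   // ((f1 orElse f2) orElse f3)
println(rightNested)  // (f1 orElse (f2 orElse f3))
```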


To easily create a lot of partial functions to test, we will write a higher-order function that generates them (if you're used to Java, you can call it a factory). Each generated partial function prints its number every time its guard is evaluated, which happens on each isDefinedAt call and once more when the function is applied to an argument it matches:



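type PF = PartialFunction[Int, Int]
// Defined only at `defined`; prints `defined` every time its guard is evaluated
def genPF(defined: Int): PF = { case i if {println(defined); defined == i} => i }
val f123 = 1 to 3 map genPF reduceLeft(_ orElse _)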


Let's try it:



> f123(1)
1
1
1
> f123(2)
1
2
1
2
> f123(3)
1
2
3

Wait, what? Applying f123 once evaluates the guards up to 4 times; f123(1) alone checks f1 three times. It gets even worse with a bigger number of composed functions.



val f1to5 = 1 to 5 map genPF reduceLeft(_ orElse _)



> f1to5(2)
1
2
1
2
1
2
1
2
> f1to5(3)
1
2
3
1
2
3
1
2
3
> f1to5(4)
1
2
3
4
1
2
3
4

Let's take a closer look at how isDefinedAt and apply are defined for the function created by orElse (this is the standard library implementation at the time of writing):



def orElse[A1 <: A, B1 >: B](that: PartialFunction[A1, B1]) : PartialFunction[A1, B1] =
  new PartialFunction[A1, B1] {
    def isDefinedAt(x: A1): Boolean =
      PartialFunction.this.isDefinedAt(x) || that.isDefinedAt(x)
    def apply(x: A1): B1 =
      if (PartialFunction.this.isDefinedAt(x)) PartialFunction.this.apply(x)
      else that.apply(x)
  }


When something first asks the composed partial function whether it is defined (an enclosing orElse does exactly that before applying it), isDefinedAt checks f1 and, if needed, f2. Applying it then checks f1 again, so that it knows which one to call. This means that a check followed by an application calls isDefinedAt for f1 twice.
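Later Scala versions (2.10 onward) rebuilt orElse around applyOrElse, so the current standard library may no longer reproduce these counts. The sketch below therefore copies the orElse quoted above into a hypothetical OldOrElse class and records guard evaluations in a buffer; trace and genLoggedPF are also made-up names. It shows that the second check of f1 comes from calling isDefinedAt before apply.

```scala
import scala.collection.mutable.ArrayBuffer

type PF = PartialFunction[Int, Int]

// Records, in order, the number of every partial function whose guard is evaluated
val trace = ArrayBuffer.empty[Int]

// Hypothetical copy of the 2011-era orElse quoted above (not the current library code)
class OldOrElse(f: PF, g: PF) extends PF {
  def isDefinedAt(x: Int): Boolean = f.isDefinedAt(x) || g.isDefinedAt(x)
  def apply(x: Int): Int = if (f.isDefinedAt(x)) f(x) else g(x)
}

// Like genPF, but appends to `trace` instead of printing
def genLoggedPF(defined: Int): PF = { case i if { trace += defined; defined == i } => i }

val f12 = new OldOrElse(genLoggedPF(1), genLoggedPF(2))

// Applying directly: f1 is checked once, then f2's guard runs inside its apply
trace.clear()
f12(2)
val direct = trace.toList  // List(1, 2)

// Checking first (as an enclosing orElse does), then applying
trace.clear()
if (f12.isDefinedAt(2)) f12(2)
val checkThenApply = trace.toList  // List(1, 2, 1, 2)
```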


Given this, we can explain what happens here. With reduceLeft, the left operand of every orElse is itself a composed function, so each level first calls isDefinedAt on its left side and then calls apply on it, which repeats the same checks one level down... you know what happens when we do this again and again. Counting the printed lines, the guards are evaluated k * (n - k + 1) times, where n is the number of composed functions and k is the position of the first function that matches.
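To check the formula rather than take it on faith, here is a sketch that measures the count for n = 5. It assumes a hypothetical OldOrElse class copying the 2011-era orElse quoted above (orElse was reimplemented around applyOrElse in Scala 2.10, so the current library may give different numbers), and a made-up genLoggedPF that records guard evaluations in a buffer instead of printing them.

```scala
import scala.collection.mutable.ArrayBuffer

type PF = PartialFunction[Int, Int]

// Records, in order, the number of every partial function whose guard is evaluated
val trace = ArrayBuffer.empty[Int]

// Hypothetical copy of the 2011-era orElse quoted above (not the current library code)
class OldOrElse(f: PF, g: PF) extends PF {
  def isDefinedAt(x: Int): Boolean = f.isDefinedAt(x) || g.isDefinedAt(x)
  def apply(x: Int): Int = if (f.isDefinedAt(x)) f(x) else g(x)
}

// Stand-in for genPF that logs to `trace` instead of printing
def genLoggedPF(defined: Int): PF = { case i if { trace += defined; defined == i } => i }

val n = 5
// Left-nested like reduceLeft(_ orElse _): (((f1 orElse f2) orElse f3) orElse f4) orElse f5
val f1to5: PF = (1 to n).map(genLoggedPF).reduceLeft[PF]((a, b) => new OldOrElse(a, b))

// Guard evaluations recorded while applying f1to5 to k
def traceFor(k: Int): List[Int] = { trace.clear(); f1to5(k); trace.toList }

for (k <- 1 to n) {
  println(s"k=$k: ${traceFor(k).size} evaluations, k * (n - k + 1) = ${k * (n - k + 1)}")
}
```

The recorded sequence for k = 3 is 1 2 3 1 2 3 1 2 3, the same as the f1to5(3) transcript above.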


Luckily, there is an easy way to combine partial functions more efficiently: use reduceRight, with which each composed function is checked at most twice. Verifying this and finding out why is left as an exercise for the curious reader (as you undoubtedly are, since you're reading this).
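If you want to check your answer to the exercise, here is a sketch. It assumes a hypothetical OldOrElse class that copies the 2011-era orElse quoted above, and a made-up genLoggedPF that records every guard evaluation in a buffer. With right nesting, a direct application walks down the chain checking each function once, and only the matching function's guard runs a second time, inside its own apply.

```scala
import scala.collection.mutable.ArrayBuffer

type PF = PartialFunction[Int, Int]

// Records, in order, the number of every partial function whose guard is evaluated
val trace = ArrayBuffer.empty[Int]

// Hypothetical copy of the 2011-era orElse quoted above (not the current library code)
class OldOrElse(f: PF, g: PF) extends PF {
  def isDefinedAt(x: Int): Boolean = f.isDefinedAt(x) || g.isDefinedAt(x)
  def apply(x: Int): Int = if (f.isDefinedAt(x)) f(x) else g(x)
}

// Stand-in for genPF that logs to `trace` instead of printing
def genLoggedPF(defined: Int): PF = { case i if { trace += defined; defined == i } => i }

val n = 5
// Right-nested: f1 orElse (f2 orElse (f3 orElse (f4 orElse f5)))
val right: PF = (1 to n).map(genLoggedPF).reduceRight[PF]((a, b) => new OldOrElse(a, b))

// Guard evaluations recorded while applying `right` to k
def traceFor(k: Int): List[Int] = { trace.clear(); right(k); trace.toList }

for (k <- 1 to n) {
  println(s"k=$k: ${traceFor(k).mkString(" ")}")
}
```

For k = 3 this records 1 2 3 3, so the total grows linearly with k instead of as k * (n - k + 1).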




andThen


You would think that f1 andThen f2 should be defined only for inputs where the result of f1 is in turn accepted by f2:



val doubler: PartialFunction[Int,Int] = { case i if Set(1,2) contains i => i * 2 }
val f = doubler andThen doubler


Of course, that's not how it works. To find out whether the output of f1 is a valid input for f2, we would need to execute f1, and it's better not to do that in case it has side effects. So the combined isDefinedAt only asks f1, which means we cannot rely on calling isDefinedAt on the combined function to avoid MatchErrors:



> f isDefinedAt 2
res3: Boolean = true
> f(2)
scala.MatchError: 4
...
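The behaviour above can be reproduced with a short sketch. One assumption: Scala 2.13 added an andThen overload taking a PartialFunction, which decides isDefinedAt by running the first function and asking the second one, so the sketch upcasts the second doubler to Int => Int to select the plain-function andThen (the only one available in 2011) on any version. The checked value is a hypothetical alternative built with lift and Function.unlift; it gives the answer you would expect, at the price of executing the first function inside isDefinedAt, which is exactly the side-effect concern mentioned above.

```scala
val doubler: PartialFunction[Int, Int] = { case i if Set(1, 2) contains i => i * 2 }

// The upcast selects andThen(k: B => C), whose isDefinedAt only consults doubler
val composed: PartialFunction[Int, Int] = doubler andThen (doubler: Int => Int)

val definedAt2 = composed.isDefinedAt(2)  // true: only the first doubler is asked
val failsAt2 =
  try { composed(2); false }
  catch { case _: MatchError => true }  // doubler(4) throws MatchError

// Hypothetical alternative: check the whole chain by actually running the first step
val checked: PartialFunction[Int, Int] =
  Function.unlift((i: Int) => doubler.lift(i).flatMap(doubler.lift))

val checkedAt1 = checked.isDefinedAt(1)  // true: 1 -> 2 -> 4
val checkedAt2 = checked.isDefinedAt(2)  // false: doubler is not defined at 4
```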

Conclusion: when you're looking for performance, it always helps to understand how the abstractions you're using decompose into simpler building blocks.