Sunday, November 11, 2012

Saddling the horse with type-classes in Scala

Preface

While reading this interesting and well-written blog post on Scala, Functional Programming and Type Classes, at some point I got particularly interested in the section entitled "Yes Virginia, Scala has Type-Classes" (section 4) where the author introduces type-classes as one of Scala's more powerful features by giving an easy-to-understand example.
While this example might perhaps be a little too much for Scala beginners who happen to encounter type-classes (and implicits) for the first time, it is still well suited to show off parts of the power and beauty of Scala, especially in terms of type-safety and, to some degree, also its outstanding collections library.
Well, there is maybe one minor drawback in the mentioned example, namely that it loses type information where it arguably shouldn't, which is what I would like to address with this blog post.
For beginners, there is quite a long introduction to type-classes in Scala, along with numerous examples. Others might want to skip that "detour" and jump right to where the fun starts.

Detour

If you are already familiar with what type-classes are and where they usually come in handy, you might want to skip this section. Nevertheless, if that is not the case, I will try to give you a quick (and, agreed, incomplete) introduction. First and foremost, however, I would like to refer to this excellent paper on Type Classes as Objects and Implicits.

Alright, so what are type-classes and why would I care?

The concept of type-classes was originally developed years ago in the realm of Haskell as a means of ad-hoc polymorphism. They are of course not restricted to the Haskell world. In fact, in Scala-land they have numerous applications as well, namely to

  • require and/or provide certain characteristics for parameter types in generic programming (so-called context bounds),
  • add behaviour to existing classes without having access to or modifying their respective code (so-called view bounds), and
  • provide a mechanism for type-level programming.
As I have already talked about the third item in "full glory" (well, almost, see my related blog post), this time I will only talk about the first two items.

Context bounds

According to the Scala Language Specification:

A type parameter A of a method or non-trait class may also have one or more context bounds A : T. In this case the type parameter may be instantiated to any type S for which evidence exists at the instantiation point that S satisfies the bound T. Such evidence consists of an implicit value with type T[S].
So, what does this mean? Well, this is probably best explained in terms of some examples. There are three typical use-cases to be found with respect to context bounds (in fact they are all one and the same mechanism, each serving a slightly different purpose), one of which is...

(#1) Restricting a generic function to parameters of a certain type

Suppose you wanted to implement a function

      def doSomething[T](x: T) = ???
    

which works on some x of type T. Furthermore, suppose you wanted to restrict T to certain types. If those types were to adhere to a particular class hierarchy, the solution would be easy, say for all subclasses of MySuperClass:

      def doSomething[T <: MySuperClass](x: T) = ???
    

But what would you do if you wanted to restrict T to arbitrary classes having no relationship to each other at all, e.g. MySuperClass, Double, and Vector[Double]? What you actually could do is define a trait together with a set of values such as

      sealed trait CanDoSomething[T]

      object mySuperClassCanDoSomething extends CanDoSomething[MySuperClass]
      object doubleCanDoSomething extends CanDoSomething[Double]
      object vectorOfDoubleCanDoSomething extends CanDoSomething[Vector[Double]]
    

and then, in a first step, refurbish the above method's declaration as follows:

      def doSomething[T](x: T)(ev: CanDoSomething[T]) = ???
    

Don't worry, that's not the end. What we have done so far is nothing but define doSomething as a method which takes two arguments in a curried fashion where the second argument merely serves as evidence for the existence of a parameterized instance of CanDoSomething, hence effectively introducing a context bound on T. Now any caller of this function could only do so once she had a suitable value at hand.
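
To make this concrete, here is a minimal, self-contained sketch of how a caller would have to supply the evidence by hand at this stage (the body of doSomething just returns its argument as a placeholder):

```scala
sealed trait CanDoSomething[T]

// The evidence values, still living in the open namespace.
object doubleCanDoSomething extends CanDoSomething[Double]

// A placeholder implementation that simply returns its argument.
def doSomething[T](x: T)(ev: CanDoSomething[T]): T = x

// The caller must pass the evidence value explicitly:
doSomething(3.14)(doubleCanDoSomething)
```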

However, having to call doSomething like that and having to import those values or pollute namespaces is arguably clumsy, which is exactly the point where type-classes for context bounds emerge in the form of implicits.

Let us hence define the following companion object to CanDoSomething and move our values inside that object, only this time prefixing them with the keyword implicit:

      trait CanDoSomething[T]
      
      object CanDoSomething {
        implicit object MySuperClassCanDoSomething extends CanDoSomething[MySuperClass]
        implicit object DoubleCanDoSomething extends CanDoSomething[Double]
        implicit object VectorOfDoubleCanDoSomething extends CanDoSomething[Vector[Double]]
      }
    

We also prepend the keyword implicit to the second parameter list of doSomething

      def doSomething[T](x: T)(implicit ev: CanDoSomething[T]) = ???
    

which can be written in a shorter form as (this is really just syntactic sugar)

      def doSomething[T: CanDoSomething](x: T) = ???
    

Now we have two options: We can still call doSomething and explicitly provide a parameterized instance of CanDoSomething, or we can leave away that second argument and have the compiler find the most suitable value implicitly.
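
Both call styles are shown in this minimal sketch (again, the body of doSomething is just a placeholder):

```scala
trait CanDoSomething[T]

object CanDoSomething {
  // Found automatically by the implicit search.
  implicit object DoubleCanDoSomething extends CanDoSomething[Double]
}

def doSomething[T](x: T)(implicit ev: CanDoSomething[T]): T = x

doSomething(1.0)                                       // evidence found implicitly
doSomething(1.0)(CanDoSomething.DoubleCanDoSomething)  // evidence given explicitly
```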

At this point, you might object that this is some sort of compiler magic and hence it ought to be burned at the stake, but it really is not. There are strict rules for the resolution and priority of implicits such that there is always exactly one "winner", chosen deterministically. For the detailed rules I suggest having a look at the Scala Language Specification.

(#2) Having type-classes provide behaviour

The second, closely related, and probably the more important use-case is having type-classes provide behaviour (read: methods) for certain parameterized instances.

Not yet following along? Then let me give you another example: Suppose you wanted to implement a generic function that scales all elements of a given collection according to a given factor. How would you do that? Well, let us roll up the sleeves and try the following naive approach:

      def scale[T](xs: Traversable[T], factor: T) =
        xs map (x => x * factor)
    

But when we compile the above definition, the compiler promptly complains about the missing symbol *:

      error: value * is not a member of type parameter T
               xs map (x => x * factor)
                              ^
    

This is usually the point where confusion starts. Why on earth would the compiler not know how to multiply two elements? Isn't that just silly?
Well, yes, and no: From our point of view, we probably meant to multiply some numbers, but what we actually defined was a generic function for a parameter T which actually happens to represent any possible type, i.e. not only numeric types like Int, Double or Float, but also e.g. String, Vector or even BufferedInputStream. Now, how is the compiler supposed to know how to multiply instances of the latter classes or, more precisely, call a method obj.* on one such instance obj?

Luckily, type-classes come to the rescue again. For example, we could define the following trait and companion object (please note that, in order to avoid confusion, I will be using times instead of *):

      trait Numeric[T] {
        def times(x: T, y: T): T
      }
      
      object Numeric {
        implicit object IntIsIntegral extends Numeric[Int] {
          def times(x: Int, y: Int): Int = x * y
        }
        implicit object CharIsIntegral extends Numeric[Char] {
          def times(x: Char, y: Char): Char = (x * y).toChar
        }
        // ... something similar for Double, Float and all other
        // classes for which you explicitly want to provide the
        // corresponding behaviour.
      }
    

and subsequently redefine scale as

      def scale[T](xs: Traversable[T], factor: T)(implicit num: Numeric[T]) =
        xs map (x => num.times(x, factor))
    

which would then enable us to call scale for collections of all types for which we have explicitly provided a corresponding instance of type-class Numeric, e.g.

      scala> scale(List(1,2,3), 4)
      res1: Traversable[Int] = List(4, 8, 12)

      scala> scale(List('A', 'B'), 1.toChar)
      res2: Traversable[Char] = List(A, B)
    

It is in fact no coincidence that browsing through the documentation of the Scala Standard Library reveals that such a type-class has already been implemented in a similar (read: better) way.
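
Indeed, the standard library's scala.math.Numeric can be summoned and used directly. A quick sketch:

```scala
// Summon the standard library's Numeric instance for Int
// and use its times method.
val num = implicitly[Numeric[Int]]
num.times(3, 4)  // 12

// Generic methods such as sum and product on the collections
// are built on exactly this type-class.
List(1, 2, 3).product  // 6
```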

Still, some people might argue that most if not all of the above could be achieved by a number of overloaded methods in the respective classes, e.g. Int.times(Int), Int.times(Double), Double.times(Double), and so on. But then what would they do if they wanted to introduce another - new - numeric type, say MySuperDuperBigInt? First, they would have to provide new overloaded methods for all existing classes (Int.times(MySuperDuperBigInt), etc.) which means modifying tons of code and is probably error-prone, and secondly, how would they do that if they had no access to the code of the existing classes, e.g. to classes from the Java library (JDK)?
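
With the type-class approach, by contrast, a new numeric type requires exactly one new instance, and it works even for types whose code we do not own. A sketch (MySuperDuperBigInt is of course made up, and I restate the post's Numeric trait to keep the snippet self-contained):

```scala
trait Numeric[T] {
  def times(x: T, y: T): T
}

case class MySuperDuperBigInt(value: BigInt)

object MySuperDuperBigInt {
  // One instance in the companion object retroactively makes the
  // new type usable with scale; no existing class had to be touched.
  implicit object IsNumeric extends Numeric[MySuperDuperBigInt] {
    def times(x: MySuperDuperBigInt, y: MySuperDuperBigInt) =
      MySuperDuperBigInt(x.value * y.value)
  }
}

def scale[T](xs: Traversable[T], factor: T)(implicit num: Numeric[T]) =
  xs map (x => num.times(x, factor))

scale(List(MySuperDuperBigInt(2), MySuperDuperBigInt(3)), MySuperDuperBigInt(10))
// List(MySuperDuperBigInt(20), MySuperDuperBigInt(30))
```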

(#3) Increasing type-safety (possibly adding behaviour)

The third use-case is rather a combination of the former two. Let us start with another example where we wish to write a generic function that takes a collection of values and returns a probability density function for the corresponding random variable (a PDF can be understood as a function that evaluates to the relative likelihood that a random variable takes a given value).

Our function should return new instances of subclasses of PDF depending on the element-type of the given input collection. Precisely, we want a new instance of PDFDiscrete for e.g. collections of String whereas for collections of Double we would like to get a new instance of PDFGaussian.
A typical first attempt at defining this could be

      trait PDF[T] extends (T => Double)
      trait PDFDiscrete[T] extends PDF[T]
      trait PDFGaussian[T] extends PDF[T]

      def createPDF[T](lst: Traversable[T]): PDF[T] = ???
    

where createPDF would do all the dirty work for us and return a new instance of either PDFDiscrete or PDFGaussian.

So far, so good, but didn't we lose information about the most specific type on the way? Oh boy, we did! While the dynamic type of the return value is known at runtime, we (read: the compiler) have no way of telling the most specific static type of the return value once control flow is returned to the caller, i.e. we will only know that we got a subclass of PDF but not which one exactly (PDFDiscrete, PDFGaussian).

But there really is no reason why we should go without the exact type; after all, we want to write type-safe programs, don't we? It is not surprising that this issue can be solved with type-classes (which in this case, but not necessarily, also know how to handle the given input):

      trait CanBuildPDF[Elem, That] {
        def apply(xs: Traversable[Elem]): That
      }
      
      object CanBuildPDF {
        implicit object DiscreteFromStrings extends CanBuildPDF[String, PDFDiscrete[String]] {
          def apply(xs: Traversable[String]) =
            new PDFDiscrete[String] {
              // Determine the number of occurences of each distinct String ...
              def apply(x: String) = throw new Error("Not yet implemented.")
            }
        }
        
        implicit object GaussianFromDoubles extends CanBuildPDF[Double, PDFGaussian[Double]] {
          def apply(xs: Traversable[Double]) =
            new PDFGaussian[Double] {
              // Determine the mean and standard deviation from the given values ...
              def apply(x: Double) = throw new Error("Not yet implemented.")
            }
        }
      }
    

Here, we have a type-class CanBuildPDF which is parameterized with respect to the type Elem of the elements of the given input collection as well as the type That of the corresponding subclass of PDF (note that, for the sake of readability, I went without further constraints on That such as That <: PDF[Elem]). Each instance of this type-class knows how to create the respective subclass of PDF according to its respective input elements by means of implementing the abstract method CanBuildPDF.apply in a type-safe manner.

We would subsequently redefine createPDF so as to use (and require) one such instance as

      def createPDF[Elem, That](lst: Traversable[Elem])
          (implicit bldr: CanBuildPDF[Elem, That]): That = 
        bldr(lst)
    

and, summa summarum, that will give us exactly what we wanted:

      scala> createPDF(List("Hello", "world"))
      res20: PDFDiscrete[String] = 
      
      scala> createPDF(Array.fill(50000)(util.Random.nextDouble))
      res21: PDFGaussian[Double] = 
    

View bounds

According to the Scala Language Specification:

A type parameter A of a method or non-trait class may have one or more view bounds A <% T. In this case the type parameter may be instantiated to any type S which is convertible by application of a view to the bound T.

Alright, so here we have the promise that once we give the compiler the means of implicitly viewing a value of type A as if it were a value of type T, it will happily do so and hence enable us, for example, to call any method provided by T on the original value of type A.

How on earth is this useful, you might ask? And again, we have several use-cases which in fact all boil down to the very same thing. So let us go on.
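
Before diving into the use-cases, here is a minimal sketch of a view bound in action (half is a made-up function): T <% Double means an implicit view T => Double must exist, and the compiler applies it wherever a Double is expected.

```scala
// Inside the body, x is viewed as a Double, so / is available.
def half[T <% Double](x: T): Double = x / 2

half(3)    // the built-in Int => Double view kicks in: 1.5
half(5.0)  // 2.5
```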

(#1) Implicit conversion

For one thing, suppose you had the following representation of an n-dimensional vector:

      class Vec(val elems: Seq[Double]) {
        require(elems.length >= 1)

        def +(other: Vec): Vec =
          new Vec(for ((x,y) <- elems zip other.elems) yield x+y)

        def *(other: Vec): Double =
          (for ((x,y) <- elems zip other.elems) yield x*y) sum

        override def toString = "Vec(" + elems.mkString(",") + ")"
      }
      object Vec {
        def apply(x: Double, xs: Double*) = new Vec(x +: xs)
      }
    

Agreed, the implementation is just bad, but that is not the point here. The point is that we might wish for some kind of automatic conversion from e.g. numeric literals or even tuples to instances of Vec in order to ease syntax a bit. So instead of having to write Vec(1,2,3) + Vec(4,5,6) we prefer to express the same matter as e.g. v + (4.,5.,6.) for some instance v of Vec, or even (1.,2.,3.) * (4.,5.,6.). How would we do that?

Turns out we can deal with this by using a number of corresponding implicit conversions (they really should go into Vec's companion object):

      implicit def doubleToVec(x: Double) = new Vec(List(x))
      implicit def tuple2ToVec(x: (Double, Double)) = new Vec(List(x._1, x._2))
      implicit def tuple3ToVec(x: (Double, Double, Double)) = new Vec(List(x._1, x._2, x._3))
    

And consequently get

      scala> import Vec._
      import Vec._

      scala> 1 + Vec(3)
      res1: Vec = Vec(4.0)

      scala> (2.,4.) + Vec(4,5)         // Note the explicit use of Doubles
      res2: Vec = Vec(6.0,9.0)          // in the left-hand tuple

      scala> (1.,2.,3.) * (4.,5.,6.)    // Doubles in both tuples
      res3: Double = 32.0
    

So far, we had to be explicit about using Double for the tuple elements in the above example, because we declared our implicit conversions for Doubles only. Can we maybe also get rid of this?

Yes, sure! Similarly to what we did before in one of the examples for context bounds, we could redefine our implicit conversions such that they work for all types T for which there is a function that turns any given T into a Double. Of course that's an implicit view again! Let us hence proceed (so as to avoid more code duplication I'll only treat tuple3ToVec, the rest is of course analogous):

      implicit def tuple3ToVec[T](x: (T, T, T))(implicit conv: T => Double) =
        new Vec(List(x._1, x._2, x._3))
    

Here, conv is the function that takes a T and gives us a corresponding Double. Note that this argument needs to be implicit itself: since tuple3ToVec is applied implicitly, there is no way for us to explicitly provide a function value for conv. Luckily, the Scala library designers have already implemented such a conversion for us, so we don't need to do it ourselves.
Also note that we're not explicitly using conv. Well, we could, but we don't need to as it will (implicitly) be passed on and applied at the appropriate place:

      scala> (1,2,3) + Vec(2.,3.,4.)
      res4: Vec = Vec(3.0,5.0,7.0)
    

Lastly, since implicit conversions are used so often, there is a more compact syntax for us syntax-lovers out there:

      implicit def tuple3ToVec[T <% Double](x: (T, T, T)) = new Vec(List(x._1, x._2, x._3))
    

(#2) Pimp My Library™

One very useful thing is the ability to add behaviour to existing classes without the need to access their original code. This way, you could e.g. add methods to classes for which you feel these methods are really missing. For example, let us propose that every String should have a method runLengthEncoded which yields the corresponding encoded representation of that string. Let us start out with some kind of wrapper class that is supposed to enrich the original class String:

      class RichString(s: String) {
        lazy val runLengthEncoded: List[(Char, Int)] = {
          def auxRLE(xs: List[Char]): List[(Char, Int)] = xs match {
            case Nil => Nil
            case h :: _ =>
              val (buf, rem) = xs span (_ == h)
              (h, buf.length) :: auxRLE(rem)
          }
          auxRLE(s.toList)
        }
      }
    

This then gives us

      scala> new RichString("aaabbaac").runLengthEncoded
      res1: List[(Char, Int)] = List((a,3), (b,2), (a,2), (c,1))
    

which is not yet what we wanted. Can we somehow use implicit conversions to improve this? Well, you already know the answer. We simply define one such conversion from String to RichString as

      implicit def stringToRichString(s: String) = new RichString(s)
    

Now that we have successfully pimped the original class, we can use runLengthEncoded on just any instance of String:

      scala> "aaabbaac".runLengthEncoded
      res2: List[(Char, Int)] = List((a,3), (b,2), (a,2), (c,1))

      scala> val german="Doppelkupplungsgetriebe"; german.runLengthEncoded
      res3: List[(Char, Int)] = List((D,1), (o,1), (p,2), (e,1), (l,1), (k,1), (u,1), (p,2), (l,1), (u,1), (n,1), (g,1), (s,1), (g,1), (e,1), (t,1), (r,1), (i,1), (e,1), (b,1), (e,1))
    
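As an aside: since Scala 2.10 the wrapper class and the conversion can be fused into a single implicit class, which achieves the same thing with less ceremony (a sketch; the body is unchanged):

```scala
implicit class RichString(val s: String) {
  lazy val runLengthEncoded: List[(Char, Int)] = {
    // Recursively split off runs of equal characters.
    def auxRLE(xs: List[Char]): List[(Char, Int)] = xs match {
      case Nil => Nil
      case h :: _ =>
        val (buf, rem) = xs span (_ == h)
        (h, buf.length) :: auxRLE(rem)
    }
    auxRLE(s.toList)
  }
}

"aaabbaac".runLengthEncoded  // List((a,3), (b,2), (a,2), (c,1))
```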

(#3) Domain Specific Languages

DSLs allow for the extension of the syntax without touching the language core at all. If not overused, they're absolutely awesome, and their use is wide-spread, for example in many well-known specs or unit-test frameworks, or Akka, or just for fun.

To explain all aspects of DSLs is surely out of the scope of this document, but you should definitely go ahead and investigate this topic. By now you probably have more than just an idea about how things like these might work behind the scenes (taken from the user guide of ScalaTest):

      Array(1, 2) should equal (Array(1, 2))

      string should endWith ("world")

      sevenDotOh should be (6.9 plusOrMinus 0.2)

      map should contain key (1)
      map should contain value ("Howdy")
    

No? Implicits and friends ring a bell? Yep, that went well ☺

Riding the horse without saddle (the actual topic)

The OP's example (which I'm going to modify slightly) is all about folding over the elements of a given parameterized collection. Such a fold could e.g. be used to sum up all elements of a Vector. A typical naive attempt for this might be

      def sum[T](lst: Traversable[T]): T =
        (lst foldLeft 0)(_ + _)
    

which of course fails for two reasons: one, there is no + defined for every possible type T, and two, 0 is certainly no meaningful value for anything but numbers.

Solving this problem with the help of type-classes is rather easy. So we start out with:

      trait CanFold[-Elem, Out] {
        def sum(acc: Out, elem: Elem): Out
        def zero: Out
      }

      object CanFold {
        implicit object CanFoldInts extends CanFold[Int, Long] {
          def sum(acc: Long, elem: Int): Long = acc + elem
          def zero = 0
        }
      }

      def sum[Elem, Out](lst: Traversable[Elem])(implicit cf: CanFold[Elem, Out]): Out = 
        (lst foldLeft cf.zero)(cf.sum)
    

Let's check this:

      scala> sum(List(1,2,3,4,5,6))
      res0: Long = 21

      scala> sum(1 to 100)
      res1: Long = 5050
    

Looking good so far. But what if we now wanted to fold over collections of collections? At a first glance, that should be no problem at all, given another corresponding instance of CanFold.

      implicit def canFoldTraversables[T] =
        new CanFold[Traversable[T], Traversable[T]] {
          def sum(acc: Traversable[T], elem: Traversable[T]): Traversable[T] = acc ++ elem
          def zero = Traversable.empty[T]
        }
    

And off we go:

      scala> sum(List(List(1,2,3), List(4,5,6)))
      res1: Traversable[Int] = List(1, 2, 3, 4, 5, 6)
    

See the problem there? We are losing information about the actual type of the collection since we got back a Traversable where we really wanted a List (or whatever the designers of Scala deemed applicable depending on what we stuck in there).

Saddling the horse

What we actually need is a way to determine the most specific type of the returned collection that we can get. But how can we do that?

Well, there is this especially useful type-class which is to be found in the collections library, namely CanBuildFrom, and we set off with this:

      import scala.collection.generic.CanBuildFrom

      implicit def canFoldTraversables[Repr[T] <: Traversable[T], T, That]
        (implicit cbf: CanBuildFrom[Repr[T], T, That]): CanFold[Repr[T], That] =
        new CanFold[Repr[T], That] {
          def sum(acc: That, elem: Repr[T]) = acc ++ elem
          def zero = cbf().result
        }
    

For one thing, we made use of CanBuildFrom only to determine the most specific return type That. As soon as we try to compile that code, though, the compiler tells us that "value ++ is not a member of type parameter That". Well, darn, we should have thought of this before. How about introducing an upper bound then?

Turns out that defining an upper bound is not enough because, when we have That <: Traversable[T], That could again be just any subtype of Traversable, somewhere deep down below the type hierarchy, but the result of ++ is a Traversable and hence violates our very own type constraints.

We could thus conclude that we need a lower bound as well, leaving us with That >: Traversable[T] <: Traversable[T], but that would be nonsense or rather yield nothing more than the original approach where we always lost the most specific type. And keeping that type was our mission, remember?

So, let us rewind and see if we can still find a working solution with respect to That <: Traversable[T]. First, we remember that the compiler complained about the result type of ++. Second, we realize that we could have used the CanBuildFrom instance to create the return value of sum:

      import scala.collection.generic.CanBuildFrom

      // Note the use of type constructor `That[T]` instead of `That`
      implicit def canFoldTraversables[Repr[T] <: Traversable[T], T, That[T] <: Traversable[T]]
        (implicit cbf: CanBuildFrom[Repr[T], T, That[T]]): CanFold[Repr[T], That[T]] =
        new CanFold[Repr[T], That[T]] {
          def sum(acc: That[T], elem: Repr[T]) = {
            val builder = cbf()
            builder ++= acc
            builder ++= elem
            builder.result
          }
          def zero = cbf().result
        }
    

Good lord, this is ugly! But it yields our first working solution:

      scala> sum(List(List(1,2,3), List(4,5,6)))
      res0: List[Int] = List(1, 2, 3, 4, 5, 6)

      scala> sum(List(Set(1,2,3), Set(3,4,5)))
      res1: scala.collection.immutable.Set[Int] = Set(5, 1, 2, 3, 4)

      scala> sum(Vector(List(1,2,3), Set(3,4,5)))
      res2: scala.collection.immutable.Iterable[Int] = List(1, 2, 3, 3, 4, 5)
    

Can we do better? The answer to this rhetorical question is of course yes (you can find more details in Josh Suereth's answer to a related question on SO). By using TraversableLike instead of Traversable we get more information about the type than before. However, we must drop That in favor of Repr or else we will run into trouble with the type inferencer:

      import scala.collection.generic.CanBuildFrom
      import scala.collection._

      implicit def canFoldTraversables[Repr[T] <: TraversableLike[T, Repr[T]], T]
        (implicit cbf: CanBuildFrom[Repr[T], T, Repr[T]]): CanFold[Repr[T], Repr[T]] =
        new CanFold[Repr[T], Repr[T]] {
          def sum(acc: Repr[T], elem: Repr[T]) = acc ++ elem
          def zero = cbf().result
        }
    

This concludes this rather longish blogpost with our final working solution. Maybe the code could still have been improved in terms of using That and hence not restricting the type of the resulting collection too much, but so far I couldn't figure out how.

Please note that it was my intention to demonstrate the development cycle in this last example with all its one-way streets as I think that this is something that we all experience every now and then.

Thank you very much for reading along! The code examples can be found here. Any constructive feedback is appreciated.
