Dijkstra with Type Families

In the previous part of this series, I wrote about a more general form of Dijkstra’s Algorithm in Haskell. This implementation used a multi-param typeclass to encapsulate the behavior of a “graph” type in terms of its node type and the cost/distance type. This is a perfectly valid implementation, but I wanted to go one step further and try out a different approach to generalization.

Type Holes

In effect, I view this algorithm as having three “type holes”. In order to create a generalized algorithm, we want to allow the user to specify their own “graph” type. But we need to be able to refer to two related types (the node and the cost) in order to make our algorithm work. So we can regard each of these as a “hole” in our algorithm.

The multi-param typeclass fills all three of these holes at the “top level” of the instance definition.

class DijkstraGraph graph node cost where
    ...

But there’s a different way to approach this pattern of “one general class and two related types”. This is to use a type family.

Type Families

A type family, in the “associated type” form we’ll use here, is an extension of a typeclass. While a typeclass allows you to specify functions and expressions related to the “target” type of the class, a type family allows you to associate other types with the target type. We start our definition the same way we would with a normal typeclass, except we’ll need a special compiler extension:

{-# LANGUAGE TypeFamilies #-}

class DijkstraGraph graph where
  ...

So we’re only specifying graph as the “target type” of this class. We can then specify names for different types associated with this class using the type keyword.

class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  ...

For each type, instead of specifying a type signature after the colons (::), we specify the kind of that type. We just want base-level types for these, so they are *. (As a different example, a monad type like IO would have the kind * -> * because it takes a type parameter).
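As a side note, newer GHC versions also let you spell the kind * as Type, imported from Data.Kind. Here is a minimal sketch of the same class header written that way. The TinyGraph type and the exampleCost value are hypothetical names I made up purely to show the associated types being filled in; they are not part of the series code.

```haskell
{-# LANGUAGE TypeFamilies #-}

import Data.Kind (Type)

-- Same class header as above, but spelling the kind `*` as `Type`.
class DijkstraGraph graph where
  type DijkstraNode graph :: Type
  type DijkstraCost graph :: Type

-- Hypothetical graph type, used only for illustration.
data TinyGraph = TinyGraph

instance DijkstraGraph TinyGraph where
  type DijkstraNode TinyGraph = Char
  type DijkstraCost TinyGraph = Int

-- `DijkstraCost TinyGraph` reduces to `Int`, so a numeric literal type-checks.
exampleCost :: DijkstraCost TinyGraph
exampleCost = 5
```

Both spellings mean the same thing here, so the rest of this article sticks with *.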

Once we’ve specified these type names, we can then use them within type signatures for functions in the class. So the last piece we need is the edges function, which we can write in terms of our types.

class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

Making an Instance of the Type Family

It’s not hard now to make an instance for this class, using our existing Graph String Int concept. We again use the type keyword to specify that we are filling in the type holes, and define the edges function as we have before:

{-# LANGUAGE FlexibleInstances #-}

instance DijkstraGraph (Graph String Int) where
    type DijkstraNode (Graph String Int) = String
    type DijkstraCost (Graph String Int) = Int
    dijkstraEdges graph node = fromMaybe [] (HM.lookup node (edges graph))

This hasn’t fixed the “re-statement of parameters” issue I mentioned last time. In fact we now restate them twice. But with a more general type, we wouldn’t necessarily have to parameterize our Graph type anymore.
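To illustrate that last point, here is a rough sketch of an instance for a graph type with no type parameters at all. The RoadMap type and the roadMap value are hypothetical, and I’ve swapped in Data.Map from the containers package (instead of HashMap) only so the snippet compiles without extra dependencies.

```haskell
{-# LANGUAGE TypeFamilies #-}

import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)

class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

-- Hypothetical graph type with no type parameters (an assumption for this sketch).
newtype RoadMap = RoadMap (M.Map String [(String, Int)])

-- The node and cost types are stated only in the associated type lines,
-- and no FlexibleInstances extension is needed for this instance head.
instance DijkstraGraph RoadMap where
  type DijkstraNode RoadMap = String
  type DijkstraCost RoadMap = Int
  dijkstraEdges (RoadMap m) node = fromMaybe [] (M.lookup node m)

roadMap :: RoadMap
roadMap = RoadMap $ M.fromList
  [ ("A", [("B", 1), ("C", 20)])
  , ("B", [])
  , ("C", [])
  ]
```

With a type like this, the instance head is just RoadMap, and the node and cost types each appear exactly once.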

Updating the Type Signature

The last thing to do now is update the type signatures within our function. Instead of using separate graph, node, and cost parameters, we’ll just use one parameter g for the graph, and define everything in terms of g, DijkstraNode g, and DijkstraCost g.

First, let’s remind ourselves what this looked like for the multi-param typeclass version:

findShortestDistance ::
  forall graph node cost. (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost

And now with the type family version:

findShortestDistance ::
  forall g. (Hashable (DijkstraNode g), Eq (DijkstraNode g), Num (DijkstraCost g), Ord (DijkstraCost g), DijkstraGraph g) =>
  g -> DijkstraNode g -> DijkstraNode g -> Distance (DijkstraCost g)

The multi-param typeclass signature ends up being shorter. But the type family version actually has fewer type variables (just g instead of graph, node, and cost). It’s up to you which you prefer.
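If the long constraint list bothers you, one option is to bundle it into a constraint synonym with the ConstraintKinds extension. This is only a sketch and not something the series code does. The names GraphConstraints, cheapestNeighbor, and RoadMap are hypothetical, and I’ve used Ord on the node type (with Data.Map) in place of Hashable and Eq so that the snippet only needs base and containers.

```haskell
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

import Data.List (sortOn)
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)

class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

-- Hypothetical alias that bundles the constraints into one name.
-- (Ord on the node stands in for the article's Hashable/Eq pair.)
type GraphConstraints g =
  ( DijkstraGraph g
  , Ord (DijkstraNode g)
  , Num (DijkstraCost g)
  , Ord (DijkstraCost g)
  )

-- Hypothetical helper: the lowest-cost edge leaving a node, if any.
cheapestNeighbor :: GraphConstraints g => g -> DijkstraNode g -> Maybe (DijkstraNode g, DijkstraCost g)
cheapestNeighbor graph node = case sortOn snd (dijkstraEdges graph node) of
  [] -> Nothing
  (edge : _) -> Just edge

-- Hypothetical graph type, used only to exercise the helper.
newtype RoadMap = RoadMap (M.Map String [(String, Int)])

instance DijkstraGraph RoadMap where
  type DijkstraNode RoadMap = String
  type DijkstraCost RoadMap = Int
  dijkstraEdges (RoadMap m) node = fromMaybe [] (M.lookup node m)

roadMap :: RoadMap
roadMap = RoadMap $ M.fromList [("A", [("C", 20), ("B", 1)]), ("B", []), ("C", [])]
```

The trade-off is one more extension and one more name to look up, so whether this actually reads better is again a matter of taste.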

And don’t forget, type signatures within the function will also need to change:

processQueue ::
  DijkstraState (DijkstraNode g) (DijkstraCost g) ->
  HashMap (DijkstraNode g) (Distance (DijkstraCost g))

Aside from that, the rest of this function works!

graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]

...

>> findShortestDistance graph1 "A" "D"
Dist 40

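Unlike the multi-param typeclass version, this one has no need of specifying the final result type in the expression. Type inference works better here because the node and cost types are determined by the graph type, so once the compiler knows the graph it knows the result type too.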
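Here is a small sketch of that inference at work, separate from the full Dijkstra code. The outgoingCost helper and the RoadMap type are hypothetical names for this illustration, and Data.Map from containers stands in for HashMap so the snippet has no extra dependencies.

```haskell
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)

class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

-- Hypothetical graph type, used only for illustration.
newtype RoadMap = RoadMap (M.Map String [(String, Int)])

instance DijkstraGraph RoadMap where
  type DijkstraNode RoadMap = String
  type DijkstraCost RoadMap = Int
  dijkstraEdges (RoadMap m) node = fromMaybe [] (M.lookup node m)

-- Hypothetical helper: the summed cost of every edge leaving a node.
-- The result type is written purely in terms of g.
outgoingCost :: (DijkstraGraph g, Num (DijkstraCost g)) => g -> DijkstraNode g -> DijkstraCost g
outgoingCost graph node = sum (map snd (dijkstraEdges graph node))

roadMap :: RoadMap
roadMap = RoadMap $ M.fromList [("A", [("B", 1), ("C", 20)]), ("B", []), ("C", [])]
```

Evaluating outgoingCost roadMap "A" gives 21 with no annotation on the result. Since the argument is a RoadMap, DijkstraCost RoadMap reduces to Int and there is nothing left for the compiler to guess.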

Conclusion

Is this version better than the multi-param typeclass version? Maybe, maybe not, depending on your taste. It has some definite weaknesses: type families are a foreign concept to many Haskellers, and they require more language extensions. The type signatures are also a bit more cumbersome. But, in my opinion, the semantics are more correct, since the graph type is the primary target and the node and cost types are simply “associated” types.

In the next part of this series, we’ll apply these different approaches to some different graph types. This will demonstrate that the approach is truly general and can be used for many different problems!

Below you can find the full code, or you can follow this link to see everything on GitHub!

Appendix - Full Code

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Dijkstra3 where

import Data.Hashable (Hashable)
import qualified Data.Heap as H
import Data.Heap (MinPrioHeap)
import qualified Data.HashSet as HS
import Data.HashSet (HashSet)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist _ = False
  Dist _ <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

(!??) :: (Hashable k, Eq k) => HashMap k (Distance d) -> k -> Distance d
(!??) distanceMap key = fromMaybe Infinity (HM.lookup key distanceMap)

newtype Graph node cost = Graph
   { edges :: HashMap node [(node, cost)] }

class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

instance DijkstraGraph (Graph String Int) where
    type DijkstraNode (Graph String Int) = String
    type DijkstraCost (Graph String Int) = Int
    dijkstraEdges graph node = fromMaybe [] (HM.lookup node (edges graph))

data DijkstraState node cost = DijkstraState
  { visitedSet :: HashSet node
  , distanceMap :: HashMap node (Distance cost)
  , nodeQueue :: MinPrioHeap (Distance cost) node
  }

findShortestDistance ::
  forall g. (Hashable (DijkstraNode g), Eq (DijkstraNode g), Num (DijkstraCost g), Ord (DijkstraCost g), DijkstraGraph g) =>
  g -> DijkstraNode g -> DijkstraNode g -> Distance (DijkstraCost g)
findShortestDistance graph src dest = processQueue initialState !?? dest
  where
    initialVisited = HS.empty
    initialDistances = HM.singleton src (Dist 0)
    initialQueue = H.fromList [(Dist 0, src)]
    initialState = DijkstraState initialVisited initialDistances initialQueue

    processQueue ::
      DijkstraState (DijkstraNode g) (DijkstraCost g) ->
      HashMap (DijkstraNode g) (Distance (DijkstraCost g))
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((_, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else
          -- Update the visited set
          let v1 = HS.insert node v0
          -- Get all unvisited neighbors of our current node
              allNeighbors = dijkstraEdges graph node
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors
          -- Fold each neighbor and recursively process the queue
          in  processQueue $ foldl (foldNeighbor node) (DijkstraState v1 d0 q1) unvisitedNeighbors
    foldNeighbor current ds@(DijkstraState v1 d0 q1) (neighborNode, cost) =
      let altDistance = addDist (d0 !?? current) (Dist cost)
      in  if altDistance < d0 !?? neighborNode
            then DijkstraState v1 (HM.insert neighborNode altDistance d0) (H.insert (altDistance, neighborNode) q1)
            else ds

graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]