Dijkstra with Type Families
In the previous part of this series, I wrote about a more general form of Dijkstra’s Algorithm in Haskell. This implementation used a multi-param typeclass to encapsulate the behavior of a “graph” type in terms of its node type and the cost/distance type. This is a perfectly valid implementation, but I wanted to go one step further and try out a different approach to generalization.
Type Holes
In effect, I view this algorithm as having three “type holes”. In order to create a generalized algorithm, we want to allow the user to specify their own “graph” type. But we need to be able to refer to two related types (the node and the cost) in order to make our algorithm work. So we can regard each of these as a “hole” in our algorithm.
The multi-param typeclass fills all three of these holes at the “top level” of the instance definition.
class DijkstraGraph graph node cost where
  ...
But there’s a different way to approach this pattern of “one general class and two related types”. This is to use a type family.
Type Families
The kind of type family we’ll use here, an associated type family, is an extension of a typeclass. While a typeclass allows you to specify functions and expressions related to the “target” type of the class, an associated type family allows you to associate other types with the target type. We start our definition the same way we would with a normal typeclass, except we’ll need a language extension:
{-# LANGUAGE TypeFamilies #-}

class DijkstraGraph graph where
  ...
So we’re only specifying graph as the “target type” of this class. We can then specify names for different types associated with this class using the type keyword.
class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  ...
For each type, instead of specifying a type signature after the colons (::), we specify the kind of that type. We just want base-level types for these, so they are *. (As a different example, a monad type like IO would have the kind * -> * because it takes a type parameter.)
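As a small aside on kinds, here is a hedged sketch of the same two declarations written with Type from Data.Kind, which newer GHC versions prefer over * (the -Wstar-is-type warning flags the star). The HasContainer class and the UseMaybe type are hypothetical and not part of this series; they are only there to show an associated type of kind Type -> Type.

```haskell
{-# LANGUAGE TypeFamilies #-}

import Data.Kind (Type)

-- Same shape as the class in this article, with `Type` in place of `*`.
class DijkstraGraph graph where
  type DijkstraNode graph :: Type
  type DijkstraCost graph :: Type

-- Hypothetical class, only here to illustrate a higher kind: the
-- associated type still expects one type argument, like IO or Maybe.
class HasContainer a where
  type Container a :: Type -> Type
  wrap :: a -> Int -> Container a Int

-- Hypothetical tag type that picks Maybe as its container.
data UseMaybe = UseMaybe

instance HasContainer UseMaybe where
  type Container UseMaybe = Maybe
  wrap _ x = Just x
```

Here Container UseMaybe Int reduces to Maybe Int, in the same way that DijkstraNode graph will reduce to a concrete node type once we write an instance.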
Once we’ve specified these type names, we can then use them within type signatures for functions in the class. So the last piece we need is the edges function, which we can write in terms of our types.
class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]
Making an Instance of the Type Family
It’s not hard now to make an instance for this class, using our existing Graph String Int concept. We again use the type keyword to specify that we are filling in the type holes, and define the edges function as we have before:
{-# LANGUAGE FlexibleInstances #-}

instance DijkstraGraph (Graph String Int) where
  type DijkstraNode (Graph String Int) = String
  type DijkstraCost (Graph String Int) = Int
  dijkstraEdges graph node = fromMaybe [] (HM.lookup node (edges graph))
This hasn’t fixed the “re-statement of parameters” issue I mentioned last time. In fact, we now restate them twice. But with a more general type, we wouldn’t necessarily have to parameterize our Graph type anymore.
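To make those last two points concrete, here is a minimal sketch of two other ways an instance could look. It is not the code from this series: it assumes Data.Map in place of HashMap (so it only needs the containers package), and the NumberLine type is a hypothetical example of a graph with no type parameters at all.

```haskell
{-# LANGUAGE TypeFamilies #-}

import Data.Kind (Type)
import qualified Data.Map as M
import Data.Maybe (fromMaybe)

class DijkstraGraph graph where
  type DijkstraNode graph :: Type
  type DijkstraCost graph :: Type
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

-- Assumption: Data.Map stands in for the article's HashMap.
newtype Graph node cost = Graph
  { edges :: M.Map node [(node, cost)] }

-- One instance covering every node/cost pair. The instance head only
-- uses type variables, so FlexibleInstances is not needed here.
instance Ord node => DijkstraGraph (Graph node cost) where
  type DijkstraNode (Graph node cost) = node
  type DijkstraCost (Graph node cost) = cost
  dijkstraEdges graph node = fromMaybe [] (M.lookup node (edges graph))

-- Hypothetical unparameterized graph: an infinite number line where
-- every integer connects to its two neighbors at cost 1.
data NumberLine = NumberLine

instance DijkstraGraph NumberLine where
  type DijkstraNode NumberLine = Int
  type DijkstraCost NumberLine = Int
  dijkstraEdges _ n = [(n - 1, 1), (n + 1, 1)]
```

The first instance still names node and cost more than once, but only as variables. The second shows that the node and cost types can be fixed entirely by the instance, with nothing on the graph type itself.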
Updating the Type Signature
The last thing to do now is update the type signatures within our function. Instead of using separate graph, node, and cost parameters, we’ll just use one parameter g for the graph, and define everything in terms of g, DijkstraNode g, and DijkstraCost g.
First, let’s remind ourselves what this looked like for the multi-param typeclass version:
findShortestDistance ::
  forall graph node cost. (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost
And now with the type family version:
findShortestDistance ::
  forall g. (Hashable (DijkstraNode g), Eq (DijkstraNode g), Num (DijkstraCost g), Ord (DijkstraCost g), DijkstraGraph g) =>
  g -> DijkstraNode g -> DijkstraNode g -> Distance (DijkstraCost g)
The signature for the “simpler” multi-param typeclass ends up being shorter. But the type family version actually has fewer type arguments (just g instead of graph, node, and cost). It’s up to you which you prefer.
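If the repeated DijkstraNode g and DijkstraCost g feel noisy, one option is to give them short names with equality constraints (~). The sketch below applies this to a small hypothetical helper, pathCost, rather than to findShortestDistance itself. It also assumes Data.Map in place of HashMap and a polymorphic instance, neither of which comes from the code in this article.

```haskell
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

import Data.Kind (Type)
import qualified Data.Map as M

class DijkstraGraph graph where
  type DijkstraNode graph :: Type
  type DijkstraCost graph :: Type
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

-- Assumption: Data.Map stands in for the article's HashMap.
newtype Graph node cost = Graph
  { edges :: M.Map node [(node, cost)] }

instance Ord node => DijkstraGraph (Graph node cost) where
  type DijkstraNode (Graph node cost) = node
  type DijkstraCost (Graph node cost) = cost
  dijkstraEdges graph node = M.findWithDefault [] node (edges graph)

-- Hypothetical helper: total cost of walking the given nodes in order,
-- or Nothing if some step has no direct edge. The equality constraints
-- let the rest of the signature say `node` and `cost`.
pathCost ::
  (DijkstraGraph g, DijkstraNode g ~ node, DijkstraCost g ~ cost, Eq node, Num cost) =>
  g -> [node] -> Maybe cost
pathCost graph (a : rest@(b : _)) = do
  step <- lookup b (dijkstraEdges graph a)
  total <- pathCost graph rest
  pure (step + total)
pathCost _ _ = Just 0

graph1 :: Graph String Int
graph1 = Graph $ M.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
```

With this, pathCost graph1 ["A", "B", "D"] adds the A-to-B and B-to-D edges. The trade-off is that node and cost come back as extra type variables, even though g still determines both of them.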
And don’t forget, type signatures within the function will also need to change:
processQueue ::
  DijkstraState (DijkstraNode g) (DijkstraCost g) ->
  HashMap (DijkstraNode g) (Distance (DijkstraCost g))
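These inner signatures rely on the forall g. in the outer signature together with ScopedTypeVariables. Without them, the g in a local signature would be a brand new type variable, unrelated to the outer one. Here is a minimal sketch of the same pattern on a hypothetical helper, cheapNeighbors, again assuming Data.Map instead of HashMap.

```haskell
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

import Data.Kind (Type)
import qualified Data.Map as M

class DijkstraGraph graph where
  type DijkstraNode graph :: Type
  type DijkstraCost graph :: Type
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

-- Assumption: Data.Map stands in for the article's HashMap.
newtype Graph node cost = Graph
  { edges :: M.Map node [(node, cost)] }

instance Ord node => DijkstraGraph (Graph node cost) where
  type DijkstraNode (Graph node cost) = node
  type DijkstraCost (Graph node cost) = cost
  dijkstraEdges graph node = M.findWithDefault [] node (edges graph)

-- Hypothetical helper: neighbors reachable by one edge costing at most `limit`.
cheapNeighbors ::
  forall g. (DijkstraGraph g, Ord (DijkstraCost g)) =>
  g -> DijkstraNode g -> DijkstraCost g -> [DijkstraNode g]
cheapNeighbors graph node limit = map fst (filter withinLimit (dijkstraEdges graph node))
  where
    -- The `g` here is the one bound by the outer `forall g.`. Dropping the
    -- forall (or the extension) makes this a fresh `g`, and comparing
    -- `cost` against `limit` no longer typechecks.
    withinLimit :: (DijkstraNode g, DijkstraCost g) -> Bool
    withinLimit (_, cost) = cost <= limit
```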
Aside from that, the rest of this function works!
graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
...
>> findShortestDistance graph1 "A" "D"
Dist 40
Unlike the multi-param typeclass version, this one has no need of specifying the final result type in the expression. Type inference seems to work better here.
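A likely reason for the difference: with an associated type, the graph type alone determines the cost type, so DijkstraCost (Graph String Int) simply reduces to Int. In a multi-param class declared without functional dependencies (which I am assuming was the case for the previous version), nothing tells the compiler that graph fixes cost. As a hedged sketch, a functional dependency is one way to give the multi-param class the same property. The edgeCosts helper is hypothetical, and Data.Map again stands in for HashMap.

```haskell
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}

import qualified Data.Map as M
import Data.Maybe (fromMaybe)

-- Assumption: Data.Map stands in for the article's HashMap.
newtype Graph node cost = Graph
  { edges :: M.Map node [(node, cost)] }

-- `graph -> node cost` states that the graph type determines the other two.
class DijkstraGraph graph node cost | graph -> node cost where
  dijkstraEdges :: graph -> node -> [(node, cost)]

instance DijkstraGraph (Graph String Int) String Int where
  dijkstraEdges graph node = fromMaybe [] (M.lookup node (edges graph))

-- Hypothetical helper: the costs of all edges leaving a node.
edgeCosts :: DijkstraGraph graph node cost => graph -> node -> [cost]
edgeCosts graph node = map snd (dijkstraEdges graph node)

graph1 :: Graph String Int
graph1 = Graph $ M.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
```

With the dependency in place, an expression like edgeCosts graph1 "A" needs no result annotation, because the instance for Graph String Int pins cost to Int.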
Conclusion
Is this version better than the multi-param typeclass version? Maybe, maybe not, depending on your taste. It has some definite weaknesses: type families are a more foreign concept to most Haskellers, and they require more language extensions. The type signatures are also a bit more cumbersome. But, in my opinion, the semantics are more correct, since the graph type becomes the primary target and the node and cost types are simply “associated” types.
In the next part of this series, we’ll apply these different approaches to some different graph types. This will demonstrate that the approach is truly general and can be used for many different problems!
Below you can find the full code, or you can follow this link to see everything on GitHub!
Appendix - Full Code
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Dijkstra3 where

import Data.Hashable (Hashable)
import qualified Data.Heap as H
import Data.Heap (MinPrioHeap)
import qualified Data.HashSet as HS
import Data.HashSet (HashSet)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist x = False
  Dist x <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

(!??) :: (Hashable k, Eq k) => HashMap k (Distance d) -> k -> Distance d
(!??) distanceMap key = fromMaybe Infinity (HM.lookup key distanceMap)

newtype Graph node cost = Graph
  { edges :: HashMap node [(node, cost)] }

class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

instance DijkstraGraph (Graph String Int) where
  type DijkstraNode (Graph String Int) = String
  type DijkstraCost (Graph String Int) = Int
  dijkstraEdges graph node = fromMaybe [] (HM.lookup node (edges graph))

data DijkstraState node cost = DijkstraState
  { visitedSet :: HashSet node
  , distanceMap :: HashMap node (Distance cost)
  , nodeQueue :: MinPrioHeap (Distance cost) node
  }

findShortestDistance :: forall g. (Hashable (DijkstraNode g), Eq (DijkstraNode g), Num (DijkstraCost g), Ord (DijkstraCost g), DijkstraGraph g) => g -> DijkstraNode g -> DijkstraNode g -> Distance (DijkstraCost g)
findShortestDistance graph src dest = processQueue initialState !?? dest
  where
    initialVisited = HS.empty
    initialDistances = HM.singleton src (Dist 0)
    initialQueue = H.fromList [(Dist 0, src)]
    initialState = DijkstraState initialVisited initialDistances initialQueue

    processQueue :: DijkstraState (DijkstraNode g) (DijkstraCost g) -> HashMap (DijkstraNode g) (Distance (DijkstraCost g))
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else
          -- Update the visited set
          let v1 = HS.insert node v0
              -- Get all unvisited neighbors of our current node
              allNeighbors = dijkstraEdges graph node
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors
          -- Fold each neighbor and recursively process the queue
          in processQueue $ foldl (foldNeighbor node) (DijkstraState v1 d0 q1) unvisitedNeighbors

    foldNeighbor current ds@(DijkstraState v1 d0 q1) (neighborNode, cost) =
      let altDistance = addDist (d0 !?? current) (Dist cost)
      in if altDistance < d0 !?? neighborNode
        then DijkstraState v1 (HM.insert neighborNode altDistance d0) (H.insert (altDistance, neighborNode) q1)
        else ds

graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]