Generalizing Dijkstra's Algorithm
Earlier this week, I wrote a simplified implementation of Dijkstra’s algorithm. You can read the article for details or just look at the full code implementation here on GitHub. This implementation is fine to use within a particular project for a particular purpose, but it doesn’t generalize very well. Today we’ll explore how to make this idea more general.
I chose to do this without looking at any existing implementations of Dijkstra’s algorithm in Haskell libraries to see how my approach would be different. So at the end of this series I’ll also spend some time comparing my approach to some other ideas that exist in the Haskell world.
Parameterizing the Graph Type
So why doesn’t this approach generalize? Well, for the obvious reason that my module defines a specific type for the graph:
data Graph = Graph
  { graphEdges :: HashMap String [(String, Int)] }
So, for someone else to re-use this code in a different project, they would have to take whatever graph information they had and turn it into this specific type. Their data might not map very well onto String values for the nodes, and they might also have a different cost type in mind than the simple Int value, with Double being the most obvious example.
So we could parameterize the graph type and allow more customization of the underlying values.
data Graph node cost = Graph
  { edges :: HashMap node [(node, cost)] }
The function signature would have to change to reflect this, and we would have to impose some additional constraints on these types:
findShortestDistance :: (Hashable node, Eq node, Num cost, Ord cost) =>
  Graph node cost -> node -> node -> Distance cost
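As a rough sketch of what the parameterized type buys us, here is a graph with Int nodes and Double costs. To keep the example free of extra dependencies, it assumes Data.Map from the containers package in place of HashMap, so the node constraint becomes Ord instead of Hashable and Eq. The roads value and the neighborsOf helper are hypothetical names used only for illustration.

```haskell
import qualified Data.Map.Strict as M

-- Assumption: Data.Map (from containers) stands in for HashMap, so the
-- node type needs Ord rather than Hashable and Eq.
newtype Graph node cost = Graph
  { edges :: M.Map node [(node, cost)] }

-- Hypothetical road network: Int city ids, Double distances.
roads :: Graph Int Double
roads = Graph $ M.fromList
  [ (1, [(2, 12.5), (3, 40.0)])
  , (2, [(3, 20.25)])
  , (3, [])
  ]

-- Hypothetical helper: outgoing edges of a node, or [] if the node is unknown.
neighborsOf :: Ord node => Graph node cost -> node -> [(node, cost)]
neighborsOf g n = M.findWithDefault [] n (edges g)
```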
Graph Enumeration
But this would still leave us with an important problem. Sometimes you don’t want to enumerate the whole graph. As is, the expression you pass as the "graph" must list every edge up front, or the function won’t give you the right answer. But often the edges are too numerous to list. Instead, you want to be able to produce the edges of a particular node on demand. For example:
edgesForNode :: Graph node cost -> node -> [(node, cost)]
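To make the motivation concrete, here is a small hypothetical graph that cannot be written out as a finite edge list at all, yet whose edges are trivial to compute from any single node. The name stepOrDouble and the costs are made up for this sketch.

```haskell
-- Hypothetical implicit graph over the integers: from n you can step to
-- n + 1 at cost 1 or jump to 2 * n at cost 2. There is no finite edge
-- list to write down, but the edges of any one node are easy to compute.
stepOrDouble :: Int -> [(Int, Int)]
stepOrDouble n = [(n + 1, 1), (2 * n, 2)]
```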
How can we capture this behavior more generally in Haskell?
Using a Typeclass
Well, one of the tools we typically turn to for this task is a typeclass. We might want to define something like this:
class DijkstraGraph graph where
  dijkstraEdges :: graph node cost -> node -> [(node, cost)]
However, it quickly gets a bit strange to try to do this with a simple typeclass because of the node and cost parameters. It’s difficult to resolve the constraints we end up needing because these parameters aren’t really part of our class.
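One way this difficulty can show up is sketched below. An instance backed by a map needs a constraint on node to do its lookup, but node is not a class parameter, so the only place to put the constraint is the method signature, where it is then fixed for every instance. This sketch assumes Data.Map from containers (hence Ord rather than Hashable), and the sample value is hypothetical.

```haskell
import qualified Data.Map.Strict as M

newtype Graph node cost = Graph
  { edges :: M.Map node [(node, cost)] }

class DijkstraGraph graph where
  -- The Ord constraint has to live on the method itself, because node is
  -- not a parameter of the class. Every instance is stuck with it, even
  -- one that would rather require something else.
  dijkstraEdges :: Ord node => graph node cost -> node -> [(node, cost)]

-- Only types that take exactly two parameters can be instances.
instance DijkstraGraph Graph where
  dijkstraEdges g n = M.findWithDefault [] n (edges g)

sample :: Graph String Int
sample = Graph $ M.fromList [("A", [("B", 1)]), ("B", [])]
```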
Using a Multi-Param Typeclass
We could instead try having a multi-param typeclass like this:
{-# LANGUAGE MultiParamTypeClasses #-}
class DijkstraGraph graph node cost where
  dijkstraEdges :: graph -> node -> [(node, cost)]
This actually works more smoothly than the previous try. We can construct an instance (if we allow flexible instances).
{-# LANGUAGE FlexibleInstances #-}
import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)
instance DijkstraGraph (Graph String Int) String Int where
  dijkstraEdges g n = fromMaybe [] (HM.lookup n (edges g))
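Here is a self-contained sketch of the same class and instance that can be loaded on its own. It assumes Data.Map from containers instead of HashMap, and the sample and edgesFromA names are hypothetical. Note the type signature on edgesFromA: nothing in the class ties node and cost to graph, so the compiler needs the result type spelled out to select the instance.

```haskell
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}

import qualified Data.Map.Strict as M

-- Assumption: Data.Map replaces HashMap to avoid extra dependencies.
newtype Graph node cost = Graph
  { edges :: M.Map node [(node, cost)] }

class DijkstraGraph graph node cost where
  dijkstraEdges :: graph -> node -> [(node, cost)]

instance DijkstraGraph (Graph String Int) String Int where
  dijkstraEdges g n = M.findWithDefault [] n (edges g)

sample :: Graph String Int
sample = Graph $ M.fromList [("A", [("B", 1), ("C", 20)]), ("B", [])]

-- The signature pins down node and cost so the instance can be resolved.
edgesFromA :: [(String, Int)]
edgesFromA = dijkstraEdges sample "A"
```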
And we can actually use this class in our function now! It mainly requires changing around a few of our type signatures. We can start with our DijkstraState type, which must now be parameterized by the node and cost:
data DijkstraState node cost = DijkstraState
  { visitedSet :: HashSet node
  , distanceMap :: HashMap node (Distance cost)
  , nodeQueue :: MinPrioHeap (Distance cost) node
  }
And, of course, we would also like to generalize the type signature of our findShortestDistance function. In its simplest form, we would like to use this:
findShortestDistance :: graph -> node -> node -> Distance cost
However, a couple extra items are necessary to make this work. First, as above, our function is the correct place to assign constraints to the node and cost types. The node type must fit into our hashing structures, so it should fulfill Eq and Hashable. The cost type must be Ord and Num in order for us to perform our addition operations and use it for the heap. And last of course, we have to add the constraint regarding the DijkstraGraph itself:
findShortestDistance ::
  (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost
Now, if we want to use the graph, node, and cost types within the “inner” type signatures of our function, we need one more thing. We need a forall specifier on the function so that the compiler knows we are referring to the same types.
{-# LANGUAGE ScopedTypeVariables #-}
findShortestDistance :: forall graph node cost.
  (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost
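The effect of the forall is easier to see in a smaller function. The sketch below uses a hypothetical helper, cheapestEdge, which is not part of the article's code. Its local function pick has a signature that mentions node and cost, and those only refer to the outer types (and pick up the outer Ord cost constraint) because of ScopedTypeVariables and the explicit forall.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

-- Hypothetical helper: the lowest-cost edge in a list, if there is one.
cheapestEdge :: forall node cost. Ord cost => [(node, cost)] -> Maybe (node, cost)
cheapestEdge = foldr pick Nothing
  where
    -- Here node and cost are the same types as in the outer signature.
    -- Without the forall they would be fresh type variables with no Ord
    -- constraint, and the comparison below would not type-check.
    pick :: (node, cost) -> Maybe (node, cost) -> Maybe (node, cost)
    pick e Nothing = Just e
    pick e@(_, c) (Just best@(_, b)) = Just (if c <= b then e else best)
```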
We can now make one change to our function so that it works with our class.
processQueue :: DijkstraState node cost -> HashMap node (Distance cost)
processQueue = ...
  -- Previously
  -- allNeighbors = fromMaybe [] (HM.lookup node (edges graph))
  -- Updated
  allNeighbors = dijkstraEdges graph node
And now we’re done! Once again, we can verify the behavior. However, we do run into one difficulty: we need some extra type annotations to help the compiler figure everything out.
graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
...
>> :set -XFlexibleContexts
>> findShortestDistance graph1 "A" "D" :: Distance Int
Dist 40
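The class also lets us search a graph whose edges are never stored, which was the motivation in the enumeration section above. The following is a simplified sketch rather than the article's implementation: it assumes only the containers package, uses a Set of (cost, node) pairs as a stand-in for the heap, assumes non-negative costs, and returns Maybe cost instead of Distance cost. The Grid type and the shortest function are hypothetical names.

```haskell
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}

import qualified Data.Set as S

class DijkstraGraph graph node cost where
  dijkstraEdges :: graph -> node -> [(node, cost)]

-- Hypothetical implicit graph: a square grid with the given side length.
-- Moving to any of the four adjacent cells costs 1.
newtype Grid = Grid Int

instance DijkstraGraph Grid (Int, Int) Int where
  dijkstraEdges (Grid size) (r, c) =
    [ ((r', c'), 1)
    | (r', c') <- [(r + 1, c), (r - 1, c), (r, c + 1), (r, c - 1)]
    , r' >= 0, r' < size, c' >= 0, c' < size
    ]

-- Simplified search: the Set acts as a priority queue ordered by cost,
-- and stale queue entries are skipped via the visited set.
shortest :: forall graph node cost.
  (Ord node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Maybe cost
shortest graph src dest = go S.empty (S.singleton (0, src))
  where
    go :: S.Set node -> S.Set (cost, node) -> Maybe cost
    go visited queue = case S.minView queue of
      Nothing -> Nothing
      Just ((dist, node), rest)
        | node == dest -> Just dist
        | S.member node visited -> go visited rest
        | otherwise ->
            let neighbors = dijkstraEdges graph node :: [(node, cost)]
                queue' = foldr (\(n, c) q -> S.insert (dist + c, n) q) rest neighbors
            in go (S.insert node visited) queue'

-- Corner to corner on a 3x3 grid takes four unit steps: Just 4.
cornerToCorner :: Maybe Int
cornerToCorner = shortest (Grid 3) ((0, 0) :: (Int, Int)) (2, 2)
```

As with the GHCi session above, the call needs annotations on both the node and the result before the compiler can pick the instance.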
Conclusion
Below in the appendix is the full code for this part. You can also take a look at it on GitHub here.
For various reasons, I don’t love this attempt at generalizing. I especially don't like the "re-statement" of the parameter types in the instance. The parameters are part of the Graph type and are separately parameters of the class. This is what leads to the necessity of specifying the Distance Int type in the GHCi session above. We could avoid this if we don't parameterize our Graph type, which is definitely an option.
In the next part of this series, we'll make a second attempt at generalizing this algorithm!
Appendix
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Dijkstra2 where

import Data.Hashable (Hashable)
import qualified Data.Heap as H
import Data.Heap (MinPrioHeap)
import qualified Data.HashSet as HS
import Data.HashSet (HashSet)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist x = False
  Dist x <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

(!??) :: (Hashable k, Eq k) => HashMap k (Distance d) -> k -> Distance d
(!??) distanceMap key = fromMaybe Infinity (HM.lookup key distanceMap)

newtype Graph node cost = Graph
  { edges :: HashMap node [(node, cost)] }

class DijkstraGraph graph node cost where
  dijkstraEdges :: graph -> node -> [(node, cost)]

instance DijkstraGraph (Graph String Int) String Int where
  dijkstraEdges g n = fromMaybe [] (HM.lookup n (edges g))

data DijkstraState node cost = DijkstraState
  { visitedSet :: HashSet node
  , distanceMap :: HashMap node (Distance cost)
  , nodeQueue :: MinPrioHeap (Distance cost) node
  }

findShortestDistance :: forall graph node cost.
  (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost
findShortestDistance graph src dest = processQueue initialState !?? dest
  where
    initialVisited = HS.empty
    initialDistances = HM.singleton src (Dist 0)
    initialQueue = H.fromList [(Dist 0, src)]
    initialState = DijkstraState initialVisited initialDistances initialQueue

    processQueue :: DijkstraState node cost -> HashMap node (Distance cost)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else
          -- Update the visited set
          let v1 = HS.insert node v0
              -- Get all unvisited neighbors of our current node
              allNeighbors = dijkstraEdges graph node
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors
          -- Fold each neighbor and recursively process the queue
          in processQueue $ foldl (foldNeighbor node) (DijkstraState v1 d0 q1) unvisitedNeighbors

    foldNeighbor current ds@(DijkstraState v1 d0 q1) (neighborNode, cost) =
      let altDistance = addDist (d0 !?? current) (Dist cost)
      in if altDistance < d0 !?? neighborNode
           then DijkstraState v1 (HM.insert neighborNode altDistance d0) (H.insert (altDistance, neighborNode) q1)
           else ds

graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]