Dijkstra's Algorithm in Haskell

In some of my recent streaming sessions (some of which you can see on my YouTube channel), I spent some time playing around with Dijkstra’s algorithm. I wrote my own version of it in Haskell, tried to generalize it to work in different settings, and then used it in some examples. So for the next couple of weeks I’ll be writing about those results. Today, though, I’ll start with a quick overview of a basic Haskell approach to the problem.

Note: This article will follow the “In Depth” reading style I talked about last week. I’ll be including all the details of my code, so if you want to follow along with this article, everything should compile and work! I’ll list dependencies, imports, and the complete code in an appendix at the end.

Pseudocode

Before we can understand how to write this algorithm in Haskell specifically, we need to take a quick look at the pseudocode. This is adapted from the Wikipedia description:

function Dijkstra(Graph, source):

      for each vertex v in Graph.Vertices:
          dist[v] <- INFINITY
          add v to Q
      dist[source] <- 0

      while Q is not empty:
          u <- vertex in Q with min dist[u]
          remove u from Q

          for each neighbor v of u still in Q:
              alt <- dist[u] + Graph.Edges(u, v)
              if alt < dist[v] and dist[u] is not INFINITY:
                  dist[v] <- alt

      return dist[]

There are a few noteworthy items here. This code references two main structures. We have dist, the mapping of nodes to distances. There is also Q, which supports a special operation, “vertex in Q with min dist[u]”, as well as the membership check “still in Q”. We can actually separate Q into two items: a priority queue that gives us the unvisited node with the smallest current distance, and a set that tracks which nodes have already been “visited” (removed from Q).
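To make that split concrete, here is a minimal sketch of the “minimum unvisited node” operation. It is not the code this article builds: it assumes a Data.Set of (distance, node) pairs from the containers package as the priority queue, plain Int distances, and a separate visited set, and popUnvisited is a hypothetical helper name. Rather than deleting nodes from the middle of the queue, it discards entries for already-visited nodes when they reach the front, which is one cheap way to answer “still in Q”.

```haskell
import qualified Data.Set as Set

-- Hypothetical helper (not from the article): pop the (distance, node) pair
-- with the smallest distance whose node has not been visited yet.
-- Stale entries for visited nodes are dropped as they reach the front.
popUnvisited
  :: Set.Set String            -- visited nodes
  -> Set.Set (Int, String)     -- queue, ordered by (distance, node)
  -> Maybe ((Int, String), Set.Set (Int, String))
popUnvisited visited queue = case Set.minView queue of
  Nothing -> Nothing
  Just (entry@(_, node), rest)
    | node `Set.member` visited -> popUnvisited visited rest
    | otherwise                 -> Just (entry, rest)
```

For example, with B visited and the queue holding (1, "B"), (20, "C") and (51, "D"), the stale B entry is dropped and (20, "C") comes out next.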

With this in mind, we can break the "Dijkstra Process" into 5 distinct steps.

  1. Define our type signature
  2. Initialize a structure with the different items (Q, dist, etc.) in their initial states
  3. Write a loop for processing each element from the queue.
  4. Write an inner loop for processing each “neighbor” we encounter of the items pulled from the queue.
  5. Get our answer from the final structure.

These steps will help us organize our code. Before we dive into the algorithm itself though, we’ll want a few helpers!

Helpers

There are a few specific items that aren’t really part of the algorithm, but they’ll make things a lot smoother for us. First, we’re going to define a new “Distance” type. It wraps a general numeric type for finite distances and adds an “Infinity” constructor, which represents a node we cannot (yet) reach.

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

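We want to ensure this type has a reasonable ordering, with every finite distance less than Infinity, and that we can add values with it in a sensible way. If we instead used an integer type’s maxBound to represent infinity, adding an edge cost to it would overflow and wrap around to a negative number.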

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist x = False
  Dist x <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity
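As a quick sanity check, the sketch below repeats Distance and addDist and contrasts them with the maxBound approach. In GHC, Int addition wraps around on overflow, so a maxBound “infinity” plus an edge cost becomes a large negative number, while Infinity plus an edge cost stays Infinity. The names sentinelPlusEdge and infinityPlusEdge exist only for illustration.

```haskell
-- Definitions repeated from above so the block compiles on its own.
data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist _   = False
  Dist _   <= Infinity = True
  Dist x   <= Dist y   = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

-- Illustration only: using maxBound as "infinity" breaks as soon as we add
-- an edge cost, because GHC's Int addition wraps around to a negative value.
sentinelPlusEdge :: Int
sentinelPlusEdge = maxBound + 5

-- With the explicit Infinity constructor, adding an edge cost stays infinite.
infinityPlusEdge :: Distance Int
infinityPlusEdge = addDist Infinity (Dist 5)
```

As an aside, because Dist is declared before Infinity, writing deriving (Show, Eq, Ord) would produce the same ordering, since derived Ord instances order constructors by declaration order and then compare fields. The hand-written instance just spells the intent out.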

We’ll be tracking our distances with a HashMap, so as an extra convenience, we’ll add an operator that looks up a node in the map, returning its distance if it exists and “Infinity” otherwise.

(!??) :: (Hashable k, Eq k) => HashMap k (Distance d) -> k -> Distance d
(!??) distanceMap key = fromMaybe Infinity (HM.lookup key distanceMap)
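If you would rather not pull in unordered-containers, the same “missing means Infinity” lookup can be sketched with Data.Map from containers, whose findWithDefault does exactly this. This version is an assumption for illustration: it reuses the operator name but needs Ord keys instead of Hashable ones, and exampleDistances is a hypothetical sample map.

```haskell
import qualified Data.Map.Strict as Map

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

-- Look up a node's distance, treating an absent key as unreachable.
(!??) :: Ord k => Map.Map k (Distance d) -> k -> Distance d
distances !?? key = Map.findWithDefault Infinity key distances

-- Hypothetical sample data for trying the operator.
exampleDistances :: Map.Map String (Distance Int)
exampleDistances = Map.fromList [("A", Dist 0), ("B", Dist 1)]
```

Here exampleDistances !?? "B" is Dist 1, while exampleDistances !?? "Z" is Infinity.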

Type Signature

Now we move on to the next step in our process: defining a type signature. To start, we need to ask, "what kind of graph are we working with?" We’ll generalize this in the future, but for now let’s assume we are defining our graph entirely as a map from “Nodes” (represented by string names) to “Edges”, which are lists of tuples pairing a neighbor’s name with the cost of the edge to it.

newtype Graph = Graph
   { edges :: HashMap String [(String, Int)] }

Our dijkstra function will take such a graph, a starting point (the source), and an ending point (the destination) as its parameters. It will return the “Distance” from the start to the end (which could be infinite). For now, we’ll exclude returning the full path, but we’ll get to that by the end of the series.

findShortestDistance :: Graph -> String -> String -> Distance Int

With our graph defined more clearly now, we’ll want to define one more “stateful” type to use within our algorithm. From reading the pseudocode, we want this to contain three structures that change with each iteration.

  1. A set of nodes we’ve visited.
  2. The distance map, from nodes to their “current” distance values
  3. A queue allowing us to find the “unvisited node with the smallest current distance”.

For the first two items, the types are straightforward: a HashSet of strings will suffice for the visited set, and a HashMap will track distances. For the queue, we need to be a little clever. A min-priority heap (MinPrioHeap from the heap package’s Data.Heap module) that uses the distance as the priority and the node as the value lets us pull out the node with the smallest current distance.

data DijkstraState = DijkstraState
  { visitedSet :: HashSet String
  , distanceMap :: HashMap String (Distance Int)
  , nodeQueue :: MinPrioHeap (Distance Int) String
  }

Initializing Values

Now for step 2, let’s build our initial state, taking the form of the DijkstraState type we defined above. Initially, no nodes are visited. The only distance we know is a distance of 0 to the source node itself, and we also store this pair in the queue so that the source is the first node we pull out.

findShortestDistance :: Graph -> String -> String -> Distance Int
findShortestDistance graph src dest = ...
  where
    initialVisited = HS.empty
    initialDistances = HM.singleton src (Dist 0)
    initialQueue = H.fromList [(Dist 0, src)]
    initialState = DijkstraState initialVisited initialDistances initialQueue
    ...
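The same initialization can be sketched as a standalone, compilable snippet. To depend only on containers, it swaps in Data.Set for the visited set, Data.Map for the distance map, and a Set of (distance, node) pairs for the priority queue; SketchState and initialStateFor are hypothetical names, not part of the article’s code.

```haskell
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set

data Distance a = Dist a | Infinity
  deriving (Show, Eq, Ord)  -- derived Ord puts every Dist below Infinity

-- Hypothetical containers-based stand-in for DijkstraState.
data SketchState = SketchState
  { visited   :: Set.Set String
  , distances :: Map.Map String (Distance Int)
  , queue     :: Set.Set (Distance Int, String)  -- smallest pair comes first
  } deriving (Show)

-- Nothing visited yet, the source at distance 0, and the source queued first.
initialStateFor :: String -> SketchState
initialStateFor src = SketchState
  { visited   = Set.empty
  , distances = Map.singleton src (Dist 0)
  , queue     = Set.singleton (Dist 0, src)
  }
```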

Processing the Queue

Now we’ll write a looping function that processes the elements in our queue. It takes the current DijkstraState as its input and returns the final mapping of nodes to distances. Remember that the most basic way to loop in Haskell, particularly when it comes to a “while” loop, is a recursive function.

Every recursive function needs at least one base case. So let’s start with one of those. If the queue is empty, we can return the map as it is.

findShortestDistance graph src dest = ...
  where
    ...
    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      ...

Next come the cases where the queue contains at least one element. Suppose the minimum element is our destination. Then we can return the distance map immediately, as it already contains the final shortest distance to that node.

findShortestDistance graph src dest = ...
  where
    ...
    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else ...

One more special case: if the node has already been visited, this is a stale entry left in the queue, so we skip it and recurse immediately, plugging in the new queue q1 for the old queue.

findShortestDistance graph src dest = ...
  where
    ...
    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else ...

Now, on to the recursive case. In this case we will do 3 things.

  1. Pull a new node from our heap and consider that node “visited”
  2. Get all the “neighbors” of this node
  3. Process each neighbor and update its distance

Most of the work in step 3 will happen in our “inner loop”. The basics for the first two steps are quite easy.

findShortestDistance graph src dest = ...
  where
    ...
    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else
          -- Update the visited set
          let v1 = HS.insert node v0
          -- Get all unvisited neighbors of our current node
              allNeighbors = fromMaybe [] (HM.lookup node (edges graph))
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors
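Here is the neighbor lookup on its own as a runnable sketch. It assumes Data.Map and Data.Set in place of HashMap and HashSet, and unvisitedNeighbors is a hypothetical standalone version of the let bindings above. As in the article, a node with no entry in the edge map is treated as having no neighbors.

```haskell
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Maybe (fromMaybe)

type Edges = Map.Map String [(String, Int)]

-- Hypothetical helper: neighbors of `node` (with edge costs) not yet visited.
unvisitedNeighbors :: Edges -> Set.Set String -> String -> [(String, Int)]
unvisitedNeighbors edges visited node =
  filter (\(n, _) -> not (Set.member n visited)) allNeighbors
  where
    -- A missing key means "no outgoing edges".
    allNeighbors = fromMaybe [] (Map.lookup node edges)

-- The article's example graph, rebuilt with Data.Map.
exampleEdges :: Edges
exampleEdges = Map.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
```

With A and B visited, unvisitedNeighbors exampleEdges (Set.fromList ["A", "B"]) "A" keeps only [("D",100),("C",20)].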

Now we just need to process each neighbor. We can do this using a fold. Our “folding function” will have a type signature that incorporates the current node as a “fixed” argument while otherwise following the a -> b -> a pattern of a left fold. Each step will incorporate a new node with its cost and update the DijkstraState. This means the a value in our folding function is DijkstraState, while the b value is (String, Int).

foldNeighbor :: String -> DijkstraState -> (String, Int) -> DijkstraState
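To see why the fixed argument comes first, here is a toy fold with the same shape (fixed -> accumulator -> item -> accumulator). The names step and labelled are hypothetical, and the accumulator is just a list of labelled edges rather than a DijkstraState, but partially applying step "A" produces exactly the two-argument function foldl expects, just as foldNeighbor node does.

```haskell
-- Same shape as foldNeighbor: a fixed first argument, then the usual
-- (accumulator -> item -> accumulator) of a left fold.
step :: String -> [(String, Int)] -> (String, Int) -> [(String, Int)]
step from acc (to, cost) = acc ++ [(from ++ "->" ++ to, cost)]

-- Partially applying `step "A"` gives a function of type
-- [(String, Int)] -> (String, Int) -> [(String, Int)], ready for foldl.
labelled :: [(String, Int)]
labelled = foldl (step "A") [] [("B", 1), ("C", 20)]
```

Here labelled evaluates to [("A->B",1),("A->C",20)].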

With this type signature set, we can now complete our processQueue function before implementing this inner loop. We call foldl over the unvisited neighbors, and then we recurse on the resulting DijkstraState.

findShortestDistance graph src dest = ...
  where
    ...
    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      ...
        else
          -- Update the visited set
          let v1 = HS.insert node v0
          -- Get all unvisited neighbors of our current node
              allNeighbors = fromMaybe [] (HM.lookup node (edges graph))
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors
          -- Fold each neighbor and recursively process the queue
          in  processQueue $ foldl (foldNeighbor node) (DijkstraState v1 d0 q1) unvisitedNeighbors

The Final Fold

Now let’s write this final fold, our “inner loop” function foldNeighbor. The core job of this function is to calculate the “alternative” distance to the given neighbor by going “through” the current node. This consists of taking the distance from the source to the current node (which is stored in the distance map) and adding it to the specific edge cost from the current to this new node.

foldNeighbor :: String -> DijkstraState -> (String, Int) -> DijkstraState
foldNeighbor current (DijkstraState v1 d0 q1) (neighborNode, cost) =
  let altDistance = addDist (d0 !?? current) (Dist cost)
  ...

We can then compare this distance to the existing distance we have to the neighbor in our distance map (or Infinity if it doesn’t exist, remember).

foldNeighbor current ds@(DijkstraState _ d0 _) (neighborNode, cost) =
  let altDistance = addDist (d0 !?? current) (Dist cost)
  in  if altDistance < d0 !?? neighborNode
  ...

If the alternative distance is smaller, we update the distance map by associating the neighbor node with the alternative distance and return the new DijkstraState. We also insert the new distance into our queue. If the alternative distance is not better, we make no changes, and return the original state.

foldNeighbor current ds@(DijkstraState v1 d0 q1) (neighborNode, cost) =
  let altDistance = addDist (d0 !?? current) (Dist cost)
  in  if altDistance < d0 !?? neighborNode
        then DijkstraState v1 (HM.insert neighborNode altDistance d0) (H.insert (altDistance, neighborNode) q1)
        else ds
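Here is the relaxation step in isolation as a sketch. It assumes Data.Map for distances and a Set of (distance, node) pairs as the queue, standing in for HashMap and MinPrioHeap, and relax is a hypothetical name. Like foldNeighbor, it computes the distance through current and only updates the map and queue when that beats the recorded distance.

```haskell
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set

data Distance a = Dist a | Infinity
  deriving (Show, Eq, Ord)  -- derived Ord: every Dist sorts below Infinity

addDist :: Num a => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

type Dists = Map.Map String (Distance Int)
type Queue = Set.Set (Distance Int, String)

-- Hypothetical helper: try to improve the distance to one neighbor
-- by going through `current`.
relax :: String -> (Dists, Queue) -> (String, Int) -> (Dists, Queue)
relax current (dists, queue) (neighbor, cost)
  | alt < known = (Map.insert neighbor alt dists, Set.insert (alt, neighbor) queue)
  | otherwise   = (dists, queue)
  where
    lookupDist n = Map.findWithDefault Infinity n dists
    alt   = addDist (lookupDist current) (Dist cost)
    known = lookupDist neighbor
```

For example, if C is at distance 20 and D is currently at 51, relaxing the edge from C to D with cost 20 lowers D to Dist 40 and queues (Dist 40, "D"); relaxing the edge from A to D with cost 100 afterwards changes nothing.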

Now we are essentially done! All that’s left to do is run processQueue from the top level and look up the distance for the destination node.

findShortestDistance :: Graph -> String -> String -> Distance Int
findShortestDistance graph src dest = processQueue initialState !?? dest
  where
    initialState = ...
    processQueue = ...

Our code is complete, and we can construct a simple example and see that it works!

graph1 :: Graph
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]

...

{- GHCi -}
>> findShortestDistance graph1 "A" "D"
Dist 40
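For readers who want to try the idea without installing heap and unordered-containers, here is a compact sketch of the same algorithm using only containers: Data.Map for edges and distances, Data.Set for the visited set, and a Set of (distance, node) pairs as the priority queue. This is a swapped-in representation for illustration, not the code above, and shortestDistance is a hypothetical name.

```haskell
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.List (foldl')
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq, Ord)  -- derived Ord: every Dist sorts below Infinity

addDist :: Num a => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

type Edges = Map.Map String [(String, Int)]

shortestDistance :: Edges -> String -> String -> Distance Int
shortestDistance edges src dest =
  go Set.empty (Map.singleton src (Dist 0)) (Set.singleton (Dist 0, src))
  where
    -- Distance lookup where a missing node counts as unreachable.
    distOf dists n = Map.findWithDefault Infinity n dists

    -- Outer loop: pop the smallest (distance, node) pair from the queue.
    go visited dists queue = case Set.minView queue of
      Nothing -> distOf dists dest                           -- queue exhausted
      Just ((_, node), rest)
        | node == dest              -> distOf dists dest     -- destination reached
        | node `Set.member` visited -> go visited dists rest -- stale entry, skip
        | otherwise ->
            let visited'  = Set.insert node visited
                neighbors = [ edge | edge@(n, _) <- fromMaybe [] (Map.lookup node edges)
                                   , not (n `Set.member` visited') ]
                (dists', queue') = foldl' (relax node) (dists, rest) neighbors
            in go visited' dists' queue'

    -- Inner loop: relax one edge (n, cost) leaving `current`.
    relax current (dists, queue) (n, cost) =
      let alt = addDist (distOf dists current) (Dist cost)
      in if alt < distOf dists n
           then (Map.insert n alt dists, Set.insert (alt, n) queue)
           else (dists, queue)

graph1 :: Edges
graph1 = Map.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
```

On the same example graph, shortestDistance graph1 "A" "D" should also give Dist 40. One design difference: a Set silently merges identical (distance, node) pairs, which is harmless because a second identical entry would carry no new information.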

Next Steps

This algorithm is nice, but it’s very limited and specific to how we’ve constructed the Graph type. What if we wanted to expose a more general API to users?

Read on to see how to do this in the next part of this series!

Appendix

Here is the complete code for this article, including imports!

module DijkstraSimple where

{- required packages:
   heap, unordered-containers, hashable
-}

import qualified Data.HashMap.Strict as HM
import Data.HashMap.Strict (HashMap)
import qualified Data.Heap as H
import Data.Heap (MinPrioHeap)
import qualified Data.HashSet as HS
import Data.HashSet (HashSet)
import Data.Hashable (Hashable)
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist x = False
  Dist x <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

(!??) :: (Hashable k, Eq k) => HashMap k (Distance d) -> k -> Distance d
(!??) distanceMap key = fromMaybe Infinity (HM.lookup key distanceMap)

newtype Graph = Graph
   { edges :: HashMap String [(String, Int)] }

data DijkstraState = DijkstraState
  { visitedSet :: HashSet String
  , distanceMap :: HashMap String (Distance Int)
  , nodeQueue :: MinPrioHeap (Distance Int) String
  }


findShortestDistance :: Graph -> String -> String -> Distance Int
findShortestDistance graph src dest = processQueue initialState !?? dest
  where
    initialVisited = HS.empty
    initialDistances = HM.singleton src (Dist 0)
    initialQueue = H.fromList [(Dist 0, src)]
    initialState = DijkstraState initialVisited initialDistances initialQueue

    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else
          -- Update the visited set
          let v1 = HS.insert node v0
          -- Get all unvisited neighbors of our current node
              allNeighbors = fromMaybe [] (HM.lookup node (edges graph))
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors
          -- Fold each neighbor and recursively process the queue
          in  processQueue $ foldl (foldNeighbor node) (DijkstraState v1 d0 q1) unvisitedNeighbors
    foldNeighbor :: String -> DijkstraState -> (String, Int) -> DijkstraState
    foldNeighbor current ds@(DijkstraState v1 d0 q1) (neighborNode, cost) =
      let altDistance = addDist (d0 !?? current) (Dist cost)
      in  if altDistance < d0 !?? neighborNode
            then DijkstraState v1 (HM.insert neighborNode altDistance d0) (H.insert (altDistance, neighborNode) q1)
            else ds

graph1 :: Graph
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]