James Bowen

An Unusual Application for Dijkstra

Today will be the final write-up for a 2021 Advent of Code problem. It will also serve as a capstone for the work on Dijkstra's algorithm I did back in the summer! This problem uses Dijkstra's algorithm, but in a more unusual way! We'll be working on Day 23 from last year. And for my part, I'll say that days 21-24 were all extremely challenging, so this is one of the "final boss" puzzles!

Like our previous write-ups, this is an In-Depth walkthrough, and it's a long one! So get ready for some details! The code is available on GitHub as always so you can follow along.

Problem Statement

For this puzzle, we start with a set of tokens divided into 4 rooms with a hallway allowing them to move around.

#############
#...........#
###B#C#B#D###
  #A#D#C#A#
  #########

Our goal is to rearrange the tokens so that the A tokens are both in the first room, the B tokens are in the second room, the C tokens are in the third room, and the D tokens are in the fourth room.

#############
#...........#
###A#B#C#D###
  #A#B#C#D#
  #########

However, there are a lot of restrictions on the possible moves. First, tokens can't move past each other in the hall (or rooms). If D comes out of the fourth room first, we cannot then move the A in that room anywhere to the left. It could only go to a space on the right.

#############
#.......D...#
###B#C#B#.###
  #A#D#C#A#
  #########

Next, each token can only make two moves total. It can move into the hallway once, and then into its appropriate room. It can't take a side journey into a different room to make space for other tokens to pass.

On top of this, each token spends a certain amount of "power" (or "energy") to move per space. The different tokens spend a different amount of energy:

A = 1
B = 10
C = 100
D = 1000

So from the start position, we could spend 2000 energy to move D up to the right, and then only 9 energy to move A all the way to the left side.

#############
#.A.......D.#
###B#C#B#.###
  #A#D#C#.#
  #########

Our goal is to get the desired configuration with the least amount of energy expended.

For the "harder" version of this problem, not much changes. We just have 4 tokens per room, so more maneuvering steps are required.

#############
#...........#
###B#C#B#D###
  #D#C#B#A#
  #D#B#A#C#
  #A#D#C#A#
  #########

Solution Approach

The surprising solution approach (at least, I was surprised when I realized it could work) is to treat this like a graph problem. Each "state" of the puzzle represents a node in the graph. Any given state has "edges" representing transitions to future states of the puzzle. The edges are weighted by how much energy is required in the transition.

Once we view the problem in this way, the solution is simple. We apply a "shortest path" algorithm (like Dijkstra's) using the "end" state of the puzzle as the destination. We'll get the series of moves that uses the least total energy.

For example, the starting state from the first example would be one node. It would have an edge to the following puzzle state, with a weight of 2000, since a D is moving two spaces.

#############
#.........D.#
###B#C#B#.###
  #A#D#C#A#
  #########

There are some potential questions about the scale of this problem. If the potential number of nodes is too high, even Dijkstra's algorithm could take too long. And if the tokens could be placed arbitrarily anywhere in the puzzle space, our upper bound might be a factorial number like 23-P-16. This would be too large.

However, as a practical matter, the solution space is much smaller than this because of the many restrictions on how tokens can actually move. So we'll end up with a solution space that is still large but not intractable.
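To put a rough number on that worst case, here is a small sketch. It assumes the "hard" layout has 23 usable spaces (7 hall spaces plus 16 room slots) holding 16 tokens. The names partialPermutations, naiveUpperBound and distinctArrangements are made up for this illustration and are not part of the solution code.

```haskell
-- Ways to place k distinguishable tokens into n spaces (often written nPk).
partialPermutations :: Integer -> Integer -> Integer
partialPermutations n k = product [n - k + 1 .. n]

-- Assumption: 7 hall spaces + 16 room slots = 23 spaces, with 16 tokens.
naiveUpperBound :: Integer
naiveUpperBound = partialPermutations 23 16

-- Tokens with the same letter are interchangeable, so each group of 4
-- identical tokens over-counts by a factor of 4! = 24.
distinctArrangements :: Integer
distinctArrangements = naiveUpperBound `div` (product [1 .. 4] ^ (4 :: Int))
```

Even after dividing out the interchangeable tokens, this is only a loose ceiling. The movement rules make most of these arrangements unreachable.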

Solution Outline

To outline our solution, we first need to decide which Dijkstra library function we'll use. In order to allow monadic actions in our functions (such as logging), we'll use dijkstraM, which has the following type signature:

dijkstraM ::
  (Monad m, Foldable f, Num cost, Ord cost, Ord state) =>
  (state -> m (f state)) ->
  (state -> state -> m cost) ->
  (state -> m Bool) ->
  state ->
  m (Maybe (cost, [state]))

To make this work, we need to pick the types we'll use for state and cost. For the cost, we can rely on a simple Int. For the state, we'll create a custom GraphState type that will represent the state of the solution at a particular point in time.

data GraphState = ...

We'll expand more on exactly what information goes into this type as we go along. But now that we've defined our type, we can define the three functions that we'll use as inputs to dijkstraM:

getNeighbors :: (MonadLogger m) => GraphState -> m [GraphState]
getCost :: (MonadLogger m) => GraphState -> GraphState -> m Int
isComplete :: (MonadLogger m) => GraphState -> m Bool

We can (and will) add at least one more argument to partially apply, but still, this lets us outline what a full invocation of the function might look like:

solution :: (MonadLogger m) => GraphState -> m (Maybe (Int, [GraphState]))
solution initialState = dijkstraM getNeighbors getCost isComplete initialState
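To get a feel for how these three functions cooperate, here is a simplified, pure stand-in for a Dijkstra search. This is only a sketch for intuition: simpleDijkstra is a hypothetical name, it is not the library's implementation, it skips the monad, and it returns only the total cost rather than the path.

```haskell
import qualified Data.Set as Set

-- A pure, simplified Dijkstra: repeatedly pop the cheapest frontier entry,
-- stop when it satisfies the goal test, otherwise push its neighbors.
simpleDijkstra
  :: Ord state
  => (state -> [state])        -- neighbors of a state
  -> (state -> state -> Int)   -- cost of moving between adjacent states
  -> (state -> Bool)           -- goal test
  -> state                     -- starting state
  -> Maybe Int
simpleDijkstra neighbors cost isGoal start =
  go (Set.singleton (0, start)) Set.empty
  where
    go frontier visited = case Set.minView frontier of
      Nothing -> Nothing  -- frontier exhausted: the goal is unreachable
      Just ((dist, s), rest)
        | isGoal s -> Just dist
        | s `Set.member` visited -> go rest visited  -- stale entry, skip it
        | otherwise ->
            let next = [ (dist + cost s n, n)
                       | n <- neighbors s
                       , not (n `Set.member` visited) ]
            in  go (foldr Set.insert rest next) (Set.insert s visited)
```

The real function plays the same role, but threads a monad through each callback and also returns the list of states on the cheapest path.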

Completeness Check

So now let's start filling in these functions. We'll start with the completeness check, since that's the easiest. Because this check is run fairly often, we want to make it as quick as possible. So instead of doing a full completeness check on the state each time we call it, we'll store a specific field in the graph state called roomsFull.

data GraphState = GraphState
  { roomsFull :: Int -- Initially 0, increments when we finish a room
  ...
  }

This field will be 0 when we initialize the state, and whenever we "complete" a room in our search path, we'll bump the number up. Checking for completeness then is as simple as checking that we've completed all 4 rooms.

isComplete :: (MonadLogger m) => GraphState -> m Bool
isComplete gs = return (roomsFull gs == 4)

Cost

It would be more convenient to combine the cost with the neighbors function, like in dijkstraAssoc. But we don't have this option if we want to use a monad. Calculating the cost between two raw graph states would be a little tricky, since we'd have to go through a lot of cases to see what has actually changed.

However, it gets easier if we include the "last move" as part of the GraphState type. So let's start defining what a Move looks like. To start, we'll include a NoMove constructor for the initial position, and we'll make a note that the GraphState will include this field.

data Move =
  NoMove |
  ...

data GraphState = GraphState
  { lastMove :: Move
  , roomsFull :: Int
  ...
  }

So how do we describe a move? Because the rules are so constrained, we can be sure every move has the following:

  1. A particular token that is moving.
  2. A particular "hall space" that it is moving to or from.
  3. A particular "room" that it is moving to or from.

Each of these concepts is easily enumerated, so let's make some Enum types that are also indexable (we'll see why soon):

data Token = A | B | C | D
  deriving (Show, Eq, Ord, Enum, Ix)

data Room = RA | RB | RC | RD
  deriving (Show, Eq, Ord, Enum, Ix)

-- Can never occupy spaces above the room like H3, H5, H7, H9
data HallSpace = H1 | H2 | H4 | H6 | H8 | H10 | H11
  deriving (Show, Eq, Ord, Enum, Ix)

Now we can describe the Move constructor with these three items, as well as two more pieces of data. First, an Int paired with the room describing the "slot" of the room involved. For example, the top "slot" of a room would be 1, the space below it would be 2, and so on. Finally, we'll include a Bool telling us if the move is leaving the room (True) or entering the room (False). This won't be necessary for calculations, but it helps with debugging.

data Move =
  NoMove |
  Move Token HallSpace (Room, Int) Bool
  deriving (Show, Eq, Ord)

So what is the cost of a move? We have to calculate the distance, and we have to know the power multiplier. So let's make two constant arrays that we'll reference. First, let's match each token to its multiplier:

tokenPower :: A.Array Token Int
tokenPower = A.array (A, D) [(A, 1), (B, 10), (C, 100), (D, 1000)]

Now we want to match each pair of "hall space" and "room" with a distance measurement. This tells us how many moves it takes to get from the hall space to the space above the room. For example, the first hall space requires 2 moves to get to room A and 4 to get to room B, while the second space only requires 1 and 3 moves, respectively:

hallRoomDistance :: A.Array (HallSpace, Room) Int
hallRoomDistance = A.array ((H1, RA), (H11, RD))
  [ ((H1, RA), 2), ((H1, RB), 4), ((H1, RC), 6), ((H1, RD), 8)
  , ((H2, RA), 1), ((H2, RB), 3), ((H2, RC), 5), ((H2, RD), 7)
  ...
  ]

Here's what the complete array looks like:

hallRoomDistance :: A.Array (HallSpace, Room) Int
hallRoomDistance = A.array ((H1, RA), (H11, RD))
  [ ((H1, RA), 2), ((H1, RB), 4), ((H1, RC), 6), ((H1, RD), 8)
  , ((H2, RA), 1), ((H2, RB), 3), ((H2, RC), 5), ((H2, RD), 7)
  , ((H4, RA), 1), ((H4, RB), 1), ((H4, RC), 3), ((H4, RD), 5)
  , ((H6, RA), 3), ((H6, RB), 1), ((H6, RC), 1), ((H6, RD), 3)
  , ((H8, RA), 5), ((H8, RB), 3), ((H8, RC), 1), ((H8, RD), 1)
  , ((H10, RA), 7), ((H10, RB), 5), ((H10, RC), 3), ((H10, RD), 1)
  , ((H11, RA), 8), ((H11, RB), 6), ((H11, RC), 4), ((H11, RD), 2)
  ]
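The table can also be derived instead of typed out. As a rough sketch, assume we number the hall cells 1 through 11 from the left, so the rooms sit under columns 3, 5, 7 and 9. The helpers hallColumn, roomColumn and computedDistance are names invented for this illustration.

```haskell
data Room = RA | RB | RC | RD
  deriving (Show, Eq, Ord, Enum)

data HallSpace = H1 | H2 | H4 | H6 | H8 | H10 | H11
  deriving (Show, Eq, Ord, Enum)

-- Column of each hall space, counting the leftmost hall cell as 1.
hallColumn :: HallSpace -> Int
hallColumn H1 = 1
hallColumn H2 = 2
hallColumn H4 = 4
hallColumn H6 = 6
hallColumn H8 = 8
hallColumn H10 = 10
hallColumn H11 = 11

-- Column of the hall cell directly above each room.
roomColumn :: Room -> Int
roomColumn RA = 3
roomColumn RB = 5
roomColumn RC = 7
roomColumn RD = 9

-- Steps along the hallway between a hall space and the cell above a room.
computedDistance :: HallSpace -> Room -> Int
computedDistance h r = abs (hallColumn h - roomColumn r)
```

This agrees with the constant array entry for entry. The array version avoids recomputing the value on every cost lookup.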

Now calculating the cost is fairly straightforward. We get the distance to the room, add the slot within the room, and then multiply this by the power multiplier.

getCost :: (MonadLogger m) => GraphState -> GraphState -> m Int
getCost _ gs = if lastMove gs == NoMove
  then return 0
  else do
    let (Move token hs (rm, slot) _) = lastMove gs
    let mult = tokenPower A.! token
    let distance = slot + hallRoomDistance A.! (hs, rm)
    return $ mult * distance
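As a sanity check, here is a small standalone sketch of the same arithmetic. It takes the hall-to-room distance as a plain argument instead of looking it up in the array, and moveCost is a hypothetical helper rather than part of the solution code.

```haskell
import Data.Array (Array, Ix, array, (!))

data Token = A | B | C | D
  deriving (Show, Eq, Ord, Enum, Ix)

tokenPower :: Array Token Int
tokenPower = array (A, D) [(A, 1), (B, 10), (C, 100), (D, 1000)]

-- Energy for one move: (steps along the hall + slot depth) * multiplier.
-- hallDistance is assumed to be the value from the hallRoomDistance table.
moveCost :: Token -> Int -> Int -> Int
moveCost token hallDistance slot = (tokenPower ! token) * (hallDistance + slot)
```

The two moves from the problem statement come out as expected. Moving D from slot 1 of room D to H10 is moveCost D 1 1, which is 2000. Moving A from slot 2 of room D to H2 is moveCost A 7 2, which is 9.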

Finishing the Graph State

Our solution is starting to take on a bit more shape, but we need to complete our GraphState type before we can make further progress. But now armed with the notion of a Token, we can fill in the remaining fields that describe it. Each room has a list of tokens that are currently residing there. And then each hall space either has a token there or not, so we have Maybe Token fields for them.

data GraphState = GraphState
  { lastMove :: Move
  , roomsFull :: Int
  , roomA :: [Token]
  , roomB :: [Token]
  , roomC :: [Token]
  , roomD :: [Token]
  , hall1 :: Maybe Token
  , hall2 :: Maybe Token
  , hall4 :: Maybe Token
  , hall6 :: Maybe Token
  , hall8 :: Maybe Token
  , hall10 :: Maybe Token
  , hall11 :: Maybe Token
  }
  deriving (Show, Eq, Ord)

Sometimes it will be useful for us to access parts of the state in a general way. We might want a function to access "one of the rooms" or "one of the hall spaces". Some day, I might revise my solution to use proper Haskell "Lenses", which would be ideal for this problem. But for now we'll define a couple simple type aliases for a RoomLens to access the tokens in a general room, and a HallLens for looking at a general hall space.

type RoomLens = GraphState -> [Token]
type HallLens = GraphState -> Maybe Token

One last piece of boilerplate we'll want will be to have "split lists" for each room. Each of these is a tuple of two lists. The first list is the hall spaces to the "left" of that room, and the second has the hall spaces to the "right" of the room.

These lists will help us answer questions like, "how many empty hall spaces can we move to from this room moving left?", or "what is the first token to the right of this room?" For these to be useful, each hall space should also include the "lens" into the GraphState, so we can examine what token lives there.

For example, room A has H2 and H1 to its left (in that order), and then H4, H6, H8, H10 and H11 to its right. We'll match each HallSpace with its HallLens, so H1 combines with the hall1 field from GraphState, and so on.

aSplits :: ([(HallLens, HallSpace)], [(HallLens, HallSpace)])
aSplits =
  ( [(hall2, H2), (hall1, H1)]
  , [(hall4, H4), (hall6, H6), (hall8, H8), (hall10, H10), (hall11, H11)]
  )

Here's what the rest of those look like:

bSplits :: ([(HallLens, HallSpace)], [(HallLens, HallSpace)])
bSplits =
  ( [(hall4, H4), (hall2, H2), (hall1, H1)]
  , [(hall6, H6), (hall8, H8), (hall10, H10), (hall11, H11)]
  )

cSplits :: ([(HallLens, HallSpace)], [(HallLens, HallSpace)])
cSplits =
  ( [(hall6, H6), (hall4, H4), (hall2, H2), (hall1, H1)]
  , [(hall8, H8), (hall10, H10), (hall11, H11)]
  )

dSplits :: ([(HallLens, HallSpace)], [(HallLens, HallSpace)])
dSplits =
  ( [(hall8, H8), (hall6, H6), (hall4, H4), (hall2, H2), (hall1, H1)]
  , [(hall10, H10), (hall11, H11)]
  )

It would be easy enough to use a common function with splitAt to describe all of these. But once again, we'll reference these many times throughout the solution, so using constants instead of requiring function logic could help make our code faster.
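For reference, the splitAt version might look something like the sketch below. It leaves out the lenses and only produces the HallSpace values, and hallSplits and allHalls are illustrative names that do not appear in the solution code.

```haskell
data Room = RA | RB | RC | RD
  deriving (Show, Eq, Ord, Enum)

data HallSpace = H1 | H2 | H4 | H6 | H8 | H10 | H11
  deriving (Show, Eq, Ord, Enum)

-- All usable hall spaces from left to right.
allHalls :: [HallSpace]
allHalls = [H1, H2, H4, H6, H8, H10, H11]

-- Room A has 2 hall spaces to its left, room B has 3, and so on.
-- The left half is reversed so that both lists start nearest the room.
hallSplits :: Room -> ([HallSpace], [HallSpace])
hallSplits rm =
  let (left, right) = splitAt (fromEnum rm + 2) allHalls
  in  (reverse left, right)
```

In the real solution each HallSpace would also be paired with its HallLens, as in the constants above.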

Moves from a Particular Room

Now it's time for the third and largest piece of the puzzle: calculating the "next" states, or the "neighboring" states of a particular graph state. This means determining what moves are possible from a particular position. This is a complex problem that we'll have to keep breaking down into smaller and smaller parts.

We can first observe that every move involves one room and the hallway - there are no moves from room to room. So we can divide the work by considering all the moves concerning one particular room. Then there are three cases for each room:

  1. The room is complete; it is full of the appropriate token.
  2. The room is empty or partially full of the appropriate token.
  3. The room has mismatched tokens inside.

In case 1, we'll propose no moves involving this room. In case 2, we will try to find the appropriate token in the hall and bring it into the room (from either direction). In case 3, we will consider all the ways to move a token out of the room.

We'll do all this in a general function roomMoves. This function needs to know the room size, the appropriate token for the room, the appropriate lens for accessing the room, and finally, the split list corresponding to the room. This leads to a long type signature, but each parameter has its role:

roomMoves ::
  (MonadLogger m) =>
  Int ->
  Token ->
  Room ->
  RoomLens ->
  ([(HallLens, HallSpace)], [(HallLens, HallSpace)]) ->
  GraphState ->
  m [GraphState]
roomMoves rs tok rm roomLens splits gs = ...

For getNeighbors, all we have to do is invoke this function once for each room and combine the results.

getNeighbors :: (MonadLogger m) => Int -> GraphState -> m [GraphState]
getNeighbors rs gs = do
  arm <- roomMoves rs A RA roomA aSplits gs
  brm <- roomMoves rs B RB roomB bSplits gs
  crm <- roomMoves rs C RC roomC cSplits gs
  drm <- roomMoves rs D RD roomD dSplits gs
  return $ arm <> brm <> crm <> drm

Now back to roomMoves. Let's start by defining the three cases mentioned above. The first case is easy to complete.

roomMoves rs tok rm roomLens splits gs 
  | roomLens gs == replicate rs tok = return []
  | all (== tok) (roomLens gs) = ...
  | otherwise = ...

Now let's consider the second case. We want to search each direction from this room to try to find a hall space containing the matching token. We can do this with a recursive helper function. In the base case, we're out of hall spaces to search, so we return Nothing:

findX :: Token -> GraphState -> [(HallLens, HallSpace)] -> Maybe HallSpace
findX _ _ [] = Nothing
findX tok gs ((lens, space) : rest) = ...

Then there are three simple cases for what to do with the next space. If we have an instance of the token, return the space. If we have a different token, the answer is Nothing (we are blocked). If there is no token there, we continue the search recursively.

findX :: Token -> GraphState -> [(HallLens, HallSpace)] -> Maybe HallSpace
findX _ _ [] = Nothing
findX tok gs ((lens, space) : rest)
  | lens gs == Just tok = Just space
  | isJust (lens gs) = Nothing
  | otherwise = findX tok gs rest
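To see the three cases in action, here is a simplified standalone variant. Instead of a lens into a GraphState, each entry already carries the occupant of that hall space. The name findToken and this simplified input shape are assumptions made for the example.

```haskell
import Data.Maybe (isJust)

data Token = A | B | C | D
  deriving (Show, Eq)

data HallSpace = H1 | H2 | H4 | H6 | H8 | H10 | H11
  deriving (Show, Eq)

-- Walk outward from the room. The list is ordered nearest-first.
findToken :: Token -> [(Maybe Token, HallSpace)] -> Maybe HallSpace
findToken _ [] = Nothing
findToken tok ((occupant, space) : rest)
  | occupant == Just tok = Just space   -- found a matching token
  | isJust occupant = Nothing           -- blocked by some other token
  | otherwise = findToken tok rest      -- empty space, keep looking
```

For instance, searching for A in [(Nothing, H2), (Just A, H1)] finds H1, while [(Just B, H4), (Just A, H6)] gives Nothing because the B blocks the way.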

Using our split lists, we can find the potential spaces on the left and the right by applying our findX helper.

roomMoves rs tok rm roomLens splits gs 
  | roomLens gs == replicate rs tok = return []
  | all (== tok) (roomLens gs) = do
    let maybeLeft = findX tok gs (fst splits)
        maybeRight = findX tok gs (snd splits)
        halls = catMaybes [maybeLeft, maybeRight]
        ...
  | otherwise = ...

For right now, let's just worry about constructing the Move object. Later on, we'll fill out a function to apply this move:

applyHallMove :: Int -> RoomLens -> GraphState -> Move -> GraphState

So to finish the case, we get the "slot" number to move to by considering the current length of the room. Then we construct the Move, and apply it for each of our (up to two) possible hall spaces.

roomMoves rs tok rm roomLens splits gs
  | roomLens gs == replicate rs tok = return []
  | all (== tok) (roomLens gs) = do
    let maybeLeft = findX tok gs (fst splits)
        maybeRight = findX tok gs (snd splits)
        halls = catMaybes [maybeLeft, maybeRight]
        slot = rs - length (roomLens gs)
        moves = map (\h -> Move tok h (rm, slot) False) halls
    return $ map (applyHallMove rs roomLens gs) moves
  | otherwise = ...

Moves Out of the Room

Now let's consider the third case - moving a token out of a room. This requires finding as many consecutive "empty" hall spaces in each direction as we can. This will be another recursive helper like findX:

findEmptyHalls :: GraphState -> [(HallLens, HallSpace)] -> [HallSpace] -> [HallSpace]
findEmptyHalls _ [] accum = accum
findEmptyHalls gs ((lens, space) : rest) accum = ...

Once we hit a Just token value in the graph state, we can return our accumulated list. But otherwise we keep recursing.

findEmptyHalls :: GraphState -> [(HallLens, HallSpace)] -> [HallSpace] -> [HallSpace]
findEmptyHalls _ [] accum = accum
findEmptyHalls gs ((lens, space) : rest) accum = if isJust (lens gs) then accum
  else findEmptyHalls gs rest (space : accum)
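Here is the same idea as a standalone sketch, again using a simplified input where each entry already holds the occupant of the hall space. The name emptyHalls and the input shape are assumptions for this example.

```haskell
import Data.Maybe (isJust)

data Token = A | B | C | D
  deriving (Show, Eq)

data HallSpace = H1 | H2 | H4 | H6 | H8 | H10 | H11
  deriving (Show, Eq)

-- Collect consecutive empty spaces, stopping at the first occupied one.
-- The input list is ordered nearest-first.
emptyHalls :: [(Maybe Token, HallSpace)] -> [HallSpace]
emptyHalls = go []
  where
    go accum [] = accum
    go accum ((occupant, space) : rest)
      | isJust occupant = accum
      | otherwise = go (space : accum) rest
```

Because each space is consed onto the accumulator, the result comes back farthest-first. With [(Nothing, H4), (Nothing, H6), (Just D, H8), (Nothing, H10)] we get [H6, H4]. The order doesn't matter here, since every reachable space just becomes a candidate move.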

Now we can apply this back in our roomMoves function with both sides of the splits.

roomMoves rs tok rm roomLens splits gs 
  | roomLens gs == replicate rs tok = return []
  | all (== tok) (roomLens gs) = ...
  | otherwise = do
    let (topRoom : restRoom) = roomLens gs
        halls = findEmptyHalls gs (fst splits) [] <> findEmptyHalls gs (snd splits) []
        ...

Once again then, we calculate the "slot" value and construct the new move using each of the hall spaces. Notice that the slot calculation is different. We want to subtract the length of the "rest" of the room from the room size, since this gives the appropriate slot value.

roomMoves rs tok rm roomLens splits gs 
  | roomLens gs == replicate rs tok = return []
  | all (== tok) (roomLens gs) = ...
  | otherwise = do
    let (topRoom : restRoom) = roomLens gs
        halls = findEmptyHalls gs (fst splits) [] <> findEmptyHalls gs (snd splits) []
        slot = rs - length restRoom
        moves = map (\h -> Move topRoom h (rm, slot) True) halls
    ...

Then, as before, we'll assume we have a helper to "apply" the move, and return the new graph states. Notice this time, we set the move flag as True, since the move is coming out of the room.

applyRoomMove :: GraphState -> Token -> Move -> GraphState
applyRoomMove = ...

roomMoves rs tok rm roomLens splits gs 
  | roomLens gs == replicate rs tok = return []
  | all (== tok) (roomLens gs) = ...
  | otherwise = do
    let (topRoom : restRoom) = roomLens gs
        halls = findEmptyHalls gs (fst splits) [] <> findEmptyHalls gs (snd splits) []
        slot = rs - length restRoom
        moves = map (\h -> Move topRoom h (rm, slot) True) halls
    return $ map (applyRoomMove gs tok) moves

Now let's work on these two "apply" helpers. Each will take the current state and the Move and construct the new GraphState.

Applying moves

We'll start by applying the move from the room. Of course, for the NoMove case, we return the original state.

applyRoomMove :: GraphState -> Token -> Move -> GraphState
applyRoomMove gs roomToken NoMove = gs
applyRoomMove gs roomToken m@(Move token h (rm, slot) _) = ...

Now with all our new information, we'll update the GraphState in two stages, because this will require two case statements. First, we'll update the hall space to contain the moved token. We'll also place the move m into the lastMove spot.

applyRoomMove :: GraphState -> Token -> Move -> GraphState
applyRoomMove gs roomToken NoMove = gs
applyRoomMove gs roomToken m@(Move token h (rm, slot) _) =
  let gs2 = case h of
        H1 -> gs {hall1 = Just token, lastMove = m}
        H2 -> gs {hall2 = Just token, lastMove = m}
        H4 -> gs {hall4 = Just token, lastMove = m}
        H6 -> gs {hall6 = Just token, lastMove = m}
        H8 -> gs {hall8 = Just token, lastMove = m}
        H10 -> gs {hall10 = Just token, lastMove = m}
        H11 -> gs {hall11 = Just token, lastMove = m}
  in  ...

Now we need to modify the room to drop the top token. Unfortunately, we can't actually use a RoomLens argument in conjunction with record syntax updating, so this needs to be a case statement as well. With proper lenses, we could probably simplify this.

applyRoomMove :: GraphState -> Token -> Move -> GraphState
applyRoomMove gs roomToken NoMove = gs
applyRoomMove gs roomToken m@(Move token h (rm, slot) _) =
  let gs2 = case h of
        H1 -> gs {hall1 = Just token, lastMove = m}
        H2 -> gs {hall2 = Just token, lastMove = m}
        H4 -> gs {hall4 = Just token, lastMove = m}
        H6 -> gs {hall6 = Just token, lastMove = m}
        H8 -> gs {hall8 = Just token, lastMove = m}
        H10 -> gs {hall10 = Just token, lastMove = m}
        H11 -> gs {hall11 = Just token, lastMove = m}
  in  case rm of
    RA -> gs2 { roomA = tail (roomA gs)}
    RB -> gs2 { roomB = tail (roomB gs)}
    RC -> gs2 { roomC = tail (roomC gs)}
    RD -> gs2 { roomD = tail (roomD gs)}

That's all for applying a move from the room. Applying a move from the hall into the room is similar. But we have the extra task of determining if the destination room is now complete. So in this case we actually can make use of the RoomLens.

applyHallMove :: Int -> RoomLens -> GraphState -> Move -> GraphState
applyHallMove rs roomLens gs NoMove = gs
applyHallMove rs roomLens gs m@(Move token h (rm, slot) _) = ...

As before, we start by updating the hall space (now it's Nothing) and the lastMove field. We'll also set the roomsFull field to a new finishedCount value in this update step.

applyHallMove :: Int -> RoomLens -> GraphState -> Move -> GraphState
applyHallMove rs roomLens gs NoMove = gs
applyHallMove rs roomLens gs m@(Move token h (rm, slot) _) =
  let gs2 = case h of
        H1 -> gs {hall1 = Nothing, lastMove = m, roomsFull = finishedCount}
        H2 -> gs {hall2 = Nothing, lastMove = m, roomsFull = finishedCount}
        H4 -> gs {hall4 = Nothing, lastMove = m, roomsFull = finishedCount}
        H6 -> gs {hall6 = Nothing, lastMove = m, roomsFull = finishedCount}
        H8 -> gs {hall8 = Nothing, lastMove = m, roomsFull = finishedCount}
        H10 -> gs {hall10 = Nothing, lastMove = m, roomsFull = finishedCount}
        H11 -> gs {hall11 = Nothing, lastMove = m, roomsFull = finishedCount}
  in  ...
  where
    finishedCount = ...

How do we implement the finishedCount? It's not too difficult. The room becomes finished with this move exactly when it already holds "Room Size minus 1" tokens, so we check the length of roomLens on the original state against that value. Then the finished count increments if this is true.

applyHallMove :: Int -> RoomLens -> GraphState -> Move -> GraphState
applyHallMove rs roomLens gs NoMove = gs
applyHallMove rs roomLens gs m@(Move token h (rm, slot) _) =
  let gs2 = case h of
        H1 -> gs {hall1 = Nothing, lastMove = m, roomsFull = finishedCount}
        H2 -> gs {hall2 = Nothing, lastMove = m, roomsFull = finishedCount}
        H4 -> gs {hall4 = Nothing, lastMove = m, roomsFull = finishedCount}
        H6 -> gs {hall6 = Nothing, lastMove = m, roomsFull = finishedCount}
        H8 -> gs {hall8 = Nothing, lastMove = m, roomsFull = finishedCount}
        H10 -> gs {hall10 = Nothing, lastMove = m, roomsFull = finishedCount}
        H11 -> gs {hall11 = Nothing, lastMove = m, roomsFull = finishedCount}
  in  ...
  where
    finished = length (roomLens gs) == rs - 1
    finishedCount = if finished then roomsFull gs + 1 else roomsFull gs

Now we do the same concluding step as the room move, except this time we're adding the token to the room instead of removing it.

applyHallMove :: Int -> RoomLens -> GraphState -> Move -> GraphState
applyHallMove rs roomLens gs NoMove = gs
applyHallMove rs roomLens gs m@(Move token h (rm, slot) _) =
  let gs2 = case h of
        H1 -> gs {hall1 = Nothing, lastMove = m, roomsFull = finishedCount}
        H2 -> gs {hall2 = Nothing, lastMove = m, roomsFull = finishedCount}
        H4 -> gs {hall4 = Nothing, lastMove = m, roomsFull = finishedCount}
        H6 -> gs {hall6 = Nothing, lastMove = m, roomsFull = finishedCount}
        H8 -> gs {hall8 = Nothing, lastMove = m, roomsFull = finishedCount}
        H10 -> gs {hall10 = Nothing, lastMove = m, roomsFull = finishedCount}
        H11 -> gs {hall11 = Nothing, lastMove = m, roomsFull = finishedCount}
  in case rm of
    RA -> gs2 {roomA = A : roomA gs}
    RB -> gs2 {roomB = B : roomB gs}
    RC -> gs2 {roomC = C : roomC gs}
    RD -> gs2 {roomD = D : roomD gs}
  where
    finished = length (roomLens gs) == rs - 1
    finishedCount = if finished then roomsFull gs + 1 else roomsFull gs
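The repeated case statements come from storing each room and hall space in its own record field. As a point of comparison only, here is a sketch of an alternative representation that keys everything by Room and HallSpace in maps. This is not the approach used in this solution. MapState, leaveRoom and enterRoom are hypothetical names, and the sketch assumes the caller has already checked that the move is legal.

```haskell
import qualified Data.Map.Strict as M

data Token = A | B | C | D deriving (Show, Eq, Ord)
data Room = RA | RB | RC | RD deriving (Show, Eq, Ord)
data HallSpace = H1 | H2 | H4 | H6 | H8 | H10 | H11 deriving (Show, Eq, Ord)

-- Hypothetical alternative to GraphState (the lastMove field is omitted).
data MapState = MapState
  { msRooms :: M.Map Room [Token]        -- top of the room is the list head
  , msHall :: M.Map HallSpace Token      -- absent key means an empty space
  , msRoomsFull :: Int
  } deriving (Show, Eq, Ord)

-- Move the top token of a room into a hall space (legality assumed).
leaveRoom :: Room -> HallSpace -> MapState -> Maybe MapState
leaveRoom rm h st = case M.findWithDefault [] rm (msRooms st) of
  [] -> Nothing
  (top : rest) -> Just st
    { msRooms = M.insert rm rest (msRooms st)
    , msHall = M.insert h top (msHall st)
    }

-- Move a token from a hall space into a room, bumping msRoomsFull when
-- the room reaches the given size (legality assumed).
enterRoom :: Int -> Room -> HallSpace -> MapState -> Maybe MapState
enterRoom roomSize rm h st = do
  tok <- M.lookup h (msHall st)
  let newRoom = tok : M.findWithDefault [] rm (msRooms st)
      bump = if length newRoom == roomSize then 1 else 0
  pure st
    { msRooms = M.insert rm newRoom (msRooms st)
    , msHall = M.delete h (msHall st)
    , msRoomsFull = msRoomsFull st + bump
    }
```

The trade-off is that map lookups replace direct field access, which may or may not matter for a search that visits this many states.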

Making the Initial State

That's the conclusion of the algorithm functions; now we just need some glue, such as the initial states and pulling it all together. For the first time with our Advent of Code problems, we don't actually need to parse an input file. We could go through this process, but the "hard" input is still basically the same size, so we can just define these initial states in code.

Let's recall that our basic case looks like this:

#############
#...........#
###B#C#B#D###
  #A#D#C#A#
  #########

We'll translate it into an initial state as:

initialState1 :: GraphState
initialState1 = GraphState
  NoMove 0 [B, A] [C, D] [B, C] [D, A]
  Nothing Nothing Nothing Nothing Nothing Nothing Nothing

Our slightly harder version has the same structure, just with letters in more unusual places.

{-
#############
#...........#
###C#A#D#D###
  #B#A#B#C#
  #########
-}

initialState2 :: GraphState
initialState2 = GraphState
  NoMove 0 [C, B] [A, A] [D, B] [D, C]
  Nothing Nothing Nothing Nothing Nothing Nothing Nothing

Now for the "hard" part of the problem, we increase the room size to 4, and insert additional characters into each room. This is what those states look like.

initialState3 :: GraphState
initialState3 = GraphState
  NoMove 0 [B, D, D, A] [C, C, B, D] [B, B, A, C] [D, A, C, A]
  Nothing Nothing Nothing Nothing Nothing Nothing Nothing

initialState4 :: GraphState
initialState4 = GraphState
  NoMove 0 [C, D, D, B] [A, C, B, A] [D, B, A, B] [D, A, C, C]
  Nothing Nothing Nothing Nothing Nothing Nothing Nothing

Solving the Problem

Now we can "solve" each of the problems. Our solution code is essentially the same for each part. The "hard" part just passes 4 as the room size.

solveDay23Easy :: GraphState -> IO (Maybe Int)
solveDay23Easy gs = runStdoutLoggingT $ do
   result <- dijkstraM (getNeighbors 2) getCost isComplete gs
   case result of
    Nothing -> return Nothing
    Just (d, path) -> return $ Just d

solveDay23Hard :: GraphState -> IO (Maybe Int)
solveDay23Hard gs = runStdoutLoggingT $ do
   result <- dijkstraM (getNeighbors 4) getCost isComplete gs
   case result of
    Nothing -> return Nothing
    Just (d, path) -> return $ Just d

And now our code is complete! We can run it and get the total energy. It actually turns out to require less energy for the second case in each group:

First  Size-2: 12521
Second Size-2: 10526
First  Size-4: 44169
Second Size-4: 41284

It's also possible, if we want, to print out the "path" we took by considering the moves in each state!

solveDay23Easy :: GraphState -> IO (Maybe Int)
solveDay23Easy gs = runStdoutLoggingT $ do
   result <- dijkstraM (getNeighbors 2) getCost isComplete gs
   case result of
    Nothing -> return Nothing
    Just (d, path) -> do
      forM_ path $ \gs' -> logDebugN (pack . show $ lastMove gs')
      return $ Just d

Here's the path we take for this simple version! Remember that True moves come from the room into the hall, and False moves go from the hall back into the room.

[Debug] Move D H10 (RD,1) True
[Debug] Move A H2 (RD,2) True
[Debug] Move B H4 (RC,1) True
[Debug] Move C H6 (RB,1) True
[Debug] Move C H6 (RC,1) False
[Debug] Move D H8 (RB,2) True
[Debug] Move D H8 (RD,2) False
[Debug] Move D H10 (RD,1) False
[Debug] Move B H4 (RB,2) False
[Debug] Move B H4 (RA,1) True
[Debug] Move B H4 (RB,1) False
[Debug] Move A H2 (RA,1) False

As a final note, the scale of the search is fairly large but by no means intractable. My solution doesn't give an instant answer, but it returns within a minute or so.

Conclusion

That is all for our review of Advent of Code 2021! We'll have the video walkthrough later in the week. And then in a couple weeks, we'll be ready to start Advent of Code 2022, so stay tuned for that!

If you've enjoyed these tutorials, make sure to subscribe to our mailing list! We've got a big offer coming up next week that you won't want to miss!

James Bowen

Zoom! Enhance!

Today we'll be tackling the Day 20 problem from Advent of Code 2021. This problem is a fun take on the Zoom and Enhance cliche from TV dramas where cops and spies can always seem to get unrealistic details from grainy camera footage by "enhancing" it. We'll have a binary image and we'll need to keep applying a decoding key to expand the image.

As always, you can see all the nitty gritty details of the code at once by going to the GitHub repository I've made for these problems. If you're enjoying these in-depth walkthroughs, make sure to subscribe so you can stay up to date with the latest news.

Problem Statement

Our problem input consists of a couple sections that have "binary" data, where the . character represents 0 and the # character represents 1.

..#.#..#####.#.#.#.###.##.....###.##.#..###.####..#####..#....#..#..##..##
#..######.###...####..#..#####..##..#.#####...##.#.#..#.##..#.#......#.###
.######.###.####...#.##.##..#..#..#####.....#.#....###..#.##......#.....#.
.#..#..##..#...##.######.####.####.#.#...#.......#..#.#.#...####.##.#.....
.#..#...##.#.##..#...##.#.##..###.#......#.#.......#.#.#.####.###.##...#..
...####.#..#..#.##.#....##..#.####....##...##..#...#......#.#.......#.....
..##..####..#...#.#.#...##..#.#..###..#####........#..####......#..#

#..#.
#....
##..#
..#..
..###

The first part (which actually would appear all on one line) is a 512 character decoding key. Why length 512? Well 512 = 2^9, and we'll see in a second why the ninth power is significant.

The second part of the input is a 2D "image", represented in binary. Our goal is to "enhance" the image using the decoding key. How do we enhance it?

To get the new value at a coordinate (x, y), we have to consider the value at that coordinate together with all 8 of its neighbors.

# . . # .
#[. . .].
#[# . .]#
.[. # .].
. . # # #

The brackets show every pixel that is involved in getting the new value at the "center" of our grid. The way we get the value is to line up these pixels in binary: ...#...#. = 000100010. Then we get the decimal value (34 in this case). This tells us the new value comes from the character at index 34 of the decoder key (counting from 0), which is #. So this middle pixel will be "on" after the first expansion. Since each pixel expansion factors in 9 pixels, there are 2^9 = 512 possible values, hence the length of the decoding key.
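The conversion from a 3x3 window to a key index can be sketched in a few lines. This is only an illustration of the arithmetic. charToBit, bitsToInt and windowIndex are hypothetical names, and the sketch uses a plain Int rather than the Word64 helper used later.

```haskell
data Bit = Zero | One
  deriving (Show, Eq, Ord)

-- '#' is a lit pixel (1); anything else is treated as unlit (0).
charToBit :: Char -> Bit
charToBit '#' = One
charToBit _ = Zero

-- Read the bits most-significant first, doubling as we go.
bitsToInt :: [Bit] -> Int
bitsToInt = foldl step 0
  where
    step acc One = acc * 2 + 1
    step acc Zero = acc * 2

-- The 9 pixels of a window, listed row by row from the top left.
windowIndex :: String -> Int
windowIndex = bitsToInt . map charToBit
```

For the bracketed window above, windowIndex "...#...#." evaluates to 34, matching the hand calculation.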

All transformations happen simultaneously. What is noteworthy is that for "fringe" pixels we must account for the boundary outside the initial image. And in fact, our image then expands into this new region! The enhanced version of our first 5x5 image actually becomes size 7x7.

.##.##.
#..#.#.
##.#..#
####..#
.#..##.
..##..#
...#.#.

For the easy part, we'll do this expansion twice. For the hard part, we'll do it 50 times. Our puzzle answer is the number of pixels that are lit in the final iteration.

Solution Approach

At first glance, this problem is pretty straightforward. It's another "state evolution" problem where we take the problem in an initial state and write a function to evolve that state to the next step. Evolving a single step involves looking at the individual pixels, and applying a fairly simple algorithm to get the resulting pixel.

The ever-expanding range of coordinates is a little tricky. But if we use a structure that allows "negative" indices (and Haskell makes this easy!), it's not too bad.

There's one BIG nuance, though, with how the "infinite" image works. We still have to implicitly imagine that the enhancement algorithm is applying to all the other pixels in "infinite space". You would hope that, since all those pixels are surrounded by other "off" pixels, they remain "off".

However, my "hard" puzzle input got a decoding key with # in the 0 position, meaning that "off" pixels surrounded by other "off" pixels all turn on! Luckily, the decoder also has . in the final position, meaning that these pixels turn "off" again on the next step. However, we need to account for this on/off pattern of all these "outside pixels" since they'll affect the pixels on the fringe of our solution.

To that end, we'll need to keep track of the value of outer pixels throughout our algorithm - I'll refer to this as the "outside bit". This will impact every layer of the solution!
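Here's a small sketch of how the outside bit evolves, assuming we only look at the first and last entries of the decoding key. An all-off window decodes to index 0 and an all-on window decodes to index 511, so those two entries fully determine the background. The names nextOutsideBit and outsideBits are hypothetical, not part of the solution code.

```haskell
data Bit = Zero | One
  deriving (Eq, Show)

-- Given the first and last entries of the 512-entry decoding key, compute
-- the next value of the infinite background from its current value.
-- An all-Zero window has index 0; an all-One window has index 511.
nextOutsideBit :: Bit -> Bit -> Bit -> Bit
nextOutsideBit keyFirst _ Zero = keyFirst
nextOutsideBit _ keyLast One = keyLast

-- Background value seen by each of the first n enhancement steps,
-- starting from an all-dark background.
outsideBits :: Bit -> Bit -> Int -> [Bit]
outsideBits keyFirst keyLast n =
  take n (iterate (nextOutsideBit keyFirst keyLast) Zero)
```

With a key that starts with # and ends with ., outsideBits One Zero 4 alternates Zero, One, Zero, One. With a key that starts with ., the background stays Zero forever.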

So with that to look forward to, let's start coding!

Utilities

As always, a few utilities will benefit us. From last week's look at binary numbers, we'll use a couple helpers like the Bit type and a binary-to-decimal conversion function.

data Bit = Zero | One
  deriving (Eq, Ord)

bitsToDecimal64 :: [Bit] -> Word64

Another very useful idea is turning a nested list into a hash map. This helps simplify parsing a lot. We saw this function in the Day 11 Octopus Problem.

hashMapFromNestedLists :: [[a]] -> HashMap Coord2 a
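The implementation isn't shown here, but one plausible sketch looks like the following. To keep it runnable with only the libraries that ship with GHC, it uses Data.Map from containers instead of HashMap, and it assumes Coord2 is a (row, column) pair of Ints. The name mapFromNestedLists is hypothetical.

```haskell
import qualified Data.Map.Strict as M

-- Assumption: coordinates are (row, column) pairs of Ints.
type Coord2 = (Int, Int)

-- Pair every element with its (row, column) position, both starting at 0.
mapFromNestedLists :: [[a]] -> M.Map Coord2 a
mapFromNestedLists rows = M.fromList
  [ ((r, c), x)
  | (r, row) <- zip [0 ..] rows
  , (c, x) <- zip [0 ..] row
  ]
```

For example, mapFromNestedLists ["ab", "cd"] has 'c' at coordinate (1, 0).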

Another idea from Day 11 was getting all 8 neighbors of a 2D coordinate. Originally, we did this with (0,0) as a hard lower bound. But we can expand this idea so that the grid bounds of the function are taken as inputs. So getNeighbors8Flex takes two additional coordinate parameters to help provide those bounds for us.

getNeighbors8Flex :: Coord2 -> Coord2 -> Coord2 -> [Coord2]
getNeighbors8Flex (minRow, minCol) (maxRow, maxCol) (row, col) = catMaybes
  [maybeUpLeft, maybeUp, maybeUpRight, maybeLeft, maybeRight, maybeDownLeft, maybeDown, maybeDownRight]
  where
    maybeUp = if row > minRow then Just (row - 1, col) else Nothing
    maybeUpRight = if row > minRow && col < maxCol then Just (row - 1, col + 1) else Nothing
    maybeRight = if col < maxCol then Just (row, col + 1) else Nothing
    maybeDownRight = if row < maxRow && col < maxCol then Just (row + 1, col + 1) else Nothing
    maybeDown = if row < maxRow then Just (row + 1, col) else Nothing
    maybeDownLeft = if row < maxRow && col > minCol then Just (row + 1, col - 1) else Nothing
    maybeLeft = if col > minCol then Just (row, col - 1) else Nothing
    maybeUpLeft = if row > minRow && col > minCol then Just (row - 1, col - 1) else Nothing

Of particular note is the way we order the results. This ordering (top, then same row, then bottom), will allow us to easily decode our values for this problem.

For this problem, we effectively want "no bounds" on the coordinates, so we'll pass the minimum and maximum integers as the bounds.

getNeighbors8Unbounded :: Coord2 -> [Coord2]
getNeighbors8Unbounded = getNeighbors8Flex (minBound, minBound) (maxBound, maxBound)

Last but not least, we'll also rely on this old standby, the countWhere function, to quickly get the occurrence of certain values in a list.

countWhere :: (a -> Bool) -> [a] -> Int
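Only the signature appears above. A minimal sketch that fits it, which may differ from the actual utility, is:

```haskell
-- Count the elements of a list that satisfy the predicate.
countWhere :: (a -> Bool) -> [a] -> Int
countWhere predicate = length . filter predicate
```

So countWhere (== '#') "#..#." counts the two lit pixels in the first row of our sample image.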

Inputs

Like all Advent of Code problems, we'll start with parsing our input. We need to get everything into bits, but instead of 0 and 1 characters, we're dealing with the character . for 0, and # for 1. So we start with a choice parser to get a single pixel.

parsePixel :: (MonadLogger m) => ParsecT Void Text m Bit
parsePixel = choice [char '.' >> return Zero, char '#' >> return One]

Now we need a couple types to organize our values. The decoder map will tell us a particular bit for every index from 0-511. So we can use a hash map with Word64 as the key.

type DecoderMap = HashMap Word64 Bit

Furthermore, it's easy to see how we build this decoder from a list of bits with a simple zip:

buildDecoder :: [Bit] -> DecoderMap
buildDecoder input = HM.fromList (zip [0..] input)

For the image though, we have 2D data. So let's use a hash map over Coord2 for our ImageMap type:

type ImageMap = HashMap Coord2 Bit

We have enough tools to start writing our function now. We'll parse an initial series of pixels and build the decoder out of them, followed by a couple eol characters.

parseInput :: (MonadLogger m) => ParsecT Void Text m (DecoderMap, ImageMap)
parseInput = do
  decoderMap <- buildDecoder <$> some parsePixel
  eol >> eol
  ...

Now we'll get the 2D image. We'll start by getting a nested list structure using the sepEndBy1 ... eol trick we've seen so many times already.

parse2DImage :: (MonadLogger m) => ParsecT Void Text m [[Bit]]
parse2DImage = sepEndBy1 (some parsePixel) eol

Now to put it all together, we'll use our conversion function to get our map from the nested lists, and then we've got our two inputs: the DecoderMap and the initial ImageMap!

parseInput :: (MonadLogger m) => ParsecT Void Text m (DecoderMap, ImageMap)
parseInput = do
  decoderMap <- buildDecoder <$> some parsePixel
  eol >> eol
  image <- hashMapFromNestedLists <$> parse2DImage
  return (decoderMap, image)

Processing One Pixel

In terms of writing out the algorithm, we'll try a "bottom up" approach this time. We'll start by solving the smallest problem we can think of, which is this: For a single pixel, how do we calculate its new value in one step of expansion?

There are multiple ways to approach this piece, but the way I chose was to imagine this as a folding function. We'll start a new "enhanced" image as an empty map, and we'll insert the new pixels one-by-one using this folding function. So each iteration modifies a single Coord2 key of an ImageMap. We can fit this into a "fold" pattern if the end of this function's signature looks like this:

-- At some point we have HM.insert coord bit newImage
f :: ImageMap -> Coord2 -> m ImageMap
f newImage coord = ...

But we need some extra information in this function to solve the problem of which "bit" we're inserting. We'll need the original image of course, to find the pixels around this coordinate. We'll also need the decoding map once we convert these to a decimal index. Last of all, we need the "outside bit" discussed above in the solution approach. Here's a type signature to gather these together.

processPixel ::
  (MonadLogger m) =>
  DecoderMap ->
  ImageMap ->
  Bit ->
  ImageMap -> Coord2 -> m ImageMap
processPixel decoderMap initialImage outsideBit newImage pixel = ...

Let's start with a helper function to get the original image's bit at a particular coordinate. Whenever we do a bit lookup outside our original image, its coordinates will not exist in the initialImage map. In this case we'll use the outside bit.

processPixel decoderMap initialImage outsideBit newImage pixel = do
  ...
  where
    getBit :: Coord2 -> Bit
    getBit coord = fromMaybe outsideBit (initialImage HM.!? coord)

Now we need to get all the neighboring coordinates of this pixel. We'll use our getNeighbors8Unbounded utility from above. We could restrict ourselves to the bounds of the original, augmented by 1, but there's no particular need. We get the bit at each location, and assert that we have indeed found all 8 neighbors.

processPixel decoderMap initialImage outsideBit newImage pixel = do
  let allNeighbors = getNeighbors8Unbounded pixel
      neighborBits = getBit <$> allNeighbors
  if length allNeighbors /= 8
    then error "Must have 8 neighbors!"
    ...
  where
    getBit = ...

Now the "neighbors" function doesn't include the bit at the specific input pixel! So we have to split our neighbors and insert it into the middle like so:

processPixel decoderMap initialImage outsideBit newImage pixel = do
  let allNeighbors = getNeighbors8Unbounded pixel
      neighborBits = getBit <$> allNeighbors
  if length allNeighbors /= 8
    then error "Must have 8 neighbors!"
    else do
      let (first4, second4) = splitAt 4 neighborBits
          finalBits = first4 ++ (getBit pixel : second4)
      ...
  where
    getBit = ...

Now that we have a list of 9 bits, we can decode those bits (using bitsToDecimal64 from last time). This gives us the index to look up in our decoder, which we insert into the new image!

processPixel decoderMap initialImage outsideBit newImage pixel = do
  let allNeighbors = getNeighbors8Unbounded pixel
      neighborBits = getBit <$> allNeighbors
  if length allNeighbors /= 8
    then error "Must have 8 neighbors!"
    else do
      let (first4, second4) = splitAt 4 neighborBits
          finalBits = first4 ++ (getBit pixel : second4)
          indexToDecode = bitsToDecimal64 finalBits
          bit = decoderMap HM.! indexToDecode
      return $ HM.insert pixel bit newImage
  where
    getBit :: Coord2 -> Bit
    getBit coord = fromMaybe outsideBit (initialImage HM.!? coord)
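As an alternative to splitting the neighbor list and inserting the center, we could generate all nine window coordinates directly in the order the decoder expects. This is just a sketch under the assumption that Coord2 is a (row, column) pair of Ints, and window3x3 is a hypothetical name.

```haskell
-- Assumption: coordinates are (row, column) pairs of Ints.
type Coord2 = (Int, Int)

-- All 9 coordinates of the 3x3 window centred on a pixel, in row-major
-- order: top row left to right, then the middle row, then the bottom row.
window3x3 :: Coord2 -> [Coord2]
window3x3 (row, col) =
  [ (row + dr, col + dc) | dr <- [-1, 0, 1], dc <- [-1, 0, 1] ]
```

The pixel itself always lands at index 4, which is the same position we get from the splitAt 4 approach. Mapping getBit over this list would give the nine bits without needing the length check.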

Expanding the Image

Now that we can populate the value for a single pixel, let's step back one layer of the problem and determine how to expand the full image. As mentioned above, we ultimately want to use our function above like a fold. So we need enough arguments to reduce it to:

ImageMap -> Coord2 -> m ImageMap

Then we can start with an empty image map, and loop through every coordinate. So let's make sure we include the decoder map, the original image, and the "outside bit" in our type signature to ensure we have all the processing arguments.

expandImage :: (MonadLogger m) => DecoderMap -> ImageMap -> Bit -> m ImageMap
expandImage decoderMap image outsideBit = ...

Our chief task is to determine the coordinates to loop through. We can't just use the coordinates from the original image though. We have to expand by 1 in each direction so that the outside pixels can come into play. After adding 1, we use Data.Ix.range to interpolate all the coordinates in between our minimum and maximum.

expandImage decoderMap image outsideBit = ...
  where
    (minRow, minCol) = minimum (HM.keys image)
    (maxRow, maxCol) = maximum (HM.keys image)
    newBounds = ((minRow - 1, minCol - 1), (maxRow + 1, maxCol + 1))
    allCoords = range newBounds

And now we have all the ingredients for our fold! We partially apply decoderMap, image, and outsideBit, and then use a fresh empty image and the coordinates.

expandImage decoderMap image outsideBit = foldM
  (processPixel decoderMap image outsideBit)
  HM.empty
  allCoords
  where
    (minRow, minCol) = minimum (HM.keys image)
    (maxRow, maxCol) = maximum (HM.keys image)
    newBounds = ((minRow - 1, minCol - 1), (maxRow + 1, maxCol + 1))
    allCoords = range newBounds
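To see what range gives us for pairs, here's a standalone sketch of the bounds expansion. It assumes Coord2 is a pair of Ints, and expandedCoords is a hypothetical helper name.

```haskell
import Data.Ix (range)

-- Assumption: coordinates are (row, column) pairs of Ints.
type Coord2 = (Int, Int)

-- Grow a bounding box by one cell on every side and list every coordinate
-- inside it. On pairs, range varies the second component fastest.
expandedCoords :: (Coord2, Coord2) -> [Coord2]
expandedCoords ((minRow, minCol), (maxRow, maxCol)) =
  range ((minRow - 1, minCol - 1), (maxRow + 1, maxCol + 1))
```

For our 5x5 sample with bounds ((0, 0), (4, 4)), this yields 49 coordinates running from (-1, -1) to (5, 5), which matches the 7x7 enhanced image. One thing to keep in mind is that minimum and maximum on tuples compare lexicographically, so they only give the true corners when the keys fill a complete rectangle. That holds here because each expansion inserts every coordinate in the new bounds.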

Running the Expansion

Now that we can expand the image once, we just have to zoom out one more layer, and run the expansion a certain number of times. We'll write a recursive function that uses the decoder map, the initial image, and an integer argument for the number of steps remaining. This will return the total number of pixels that are lit in the final image.

runExpand :: (MonadLogger m) => DecoderMap -> ImageMap -> Int -> m Int

The base case occurs when we have 0 steps remaining. We'll just count the number of elements that have the One bit in our current image.

runExpand _ image 0 = return $ countWhere (== One) (HM.elems image)

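The only trick with the recursive case is that we have to determine the "outside bit". If the element corresponding to 0 in the decoder map is One, then all the outside bits will flip back and forth. So we need to check this bit, as well as the step count. Here the step count is the number of steps remaining, and both parts of the puzzle run an even number of steps in total. So an even step count means an even number of steps have already run, and we'll use Zero for the outside bits. For odd step counts, we'll use One. (This parity shortcut would need adjusting for an odd total.) And of course, if the decoder head is Zero, then there's no flipping at all, so we always get Zero.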

runExpand _ image 0 = return $ countWhere (== One) (HM.elems image)
runExpand decoderMap initialImage stepCount = do
  ...
  where
    outsideBit = if decoderMap HM.! 0 == Zero || even stepCount
      then Zero
      else One

Now we have all the arguments we need for our expandImage call! So let's get that new image and recurse using runExpand, with a reduced step count.

runExpand _ image 0 = return $ countWhere (== One) (HM.elems image)
runExpand decoderMap initialImage stepCount = do
  finalImage <- expandImage decoderMap initialImage outsideBit
  runExpand decoderMap finalImage (stepCount - 1)
  where
    outsideBit = if decoderMap HM.! 0 == Zero || even stepCount then Zero else One

Solving the Problem

Now we're well positioned to solve the problem. We'll parse the input into the decoder map and the first image with another old standby, parseFile. Then we'll run the expansion for 2 steps and return the number of lit pixels.

solveDay20Easy :: String -> IO (Maybe Int)
solveDay20Easy fp = runStdoutLoggingT $ do
  (decoderMap, initialImage) <- parseFile parseInput fp
  pixelsLit <- runExpand decoderMap initialImage 2
  return $ Just pixelsLit

The hard part is virtually identical, just increasing the number of steps up to 50.

solveDay20Hard :: String -> IO (Maybe Int)
solveDay20Hard fp = runStdoutLoggingT $ do
  (decoderMap, initialImage) <- parseFile parseInput fp
  pixelsLit <- runExpand decoderMap initialImage 50
  return $ Just pixelsLit

And we're done!

Conclusion

Later this week we'll have the video walkthrough! If you want to see the complete code in action, you can take a look on GitHub.

If you subscribe to our monthly newsletter, you'll get all the latest news and offers from Monday Morning Haskell, as well as access to our subscriber resources!

James Bowen James Bowen

Binary Packet Parsing

Today we're back with a new problem walkthrough, this time from Day 16 of last year's Advent of Code. In some sense, the parsing section for this problem is very easy - there's not much data to read from the file. In another sense, it's actually rather hard! This problem is about parsing a binary format, similar in some sense to how network packets work. It's a good exercise in handling a few different kinds of recursive cases.

As with the previous parts of this series, you can take a look at the code on GitHub here. This problem also has quite a few utilities, so you can observe those as well. This article is a deep-dive code walkthrough, so having the code handy to look at might be a good idea!

Problem Description

For this problem, we're decoding a binary packet. The packet is initially given as a hexadecimal string.

A0016C880162017C3686B18A3D4780

We'll turn it into binary and start working strictly with ones and zeros. The decoding process gets complicated because the packet is structured in a recursive way, so let's go over some of the rules.

Packet Header

Every packet has a six-bit header. The first three bits give a "version number" for the packet. The next three bits give a "type ID". That part's easy.

Then there are a series of rules about the rest of the information in the packet.

Literals

If the type ID is 4, the packet is a "literal". We then parse the remainder of the packet in 5-bit chunks. The first bit tells us if it is the last chunk of the packet (0 means yes, 1 means there are more chunks). The four other bits in the chunk are used to construct the binary number that forms the "value" of the literal. The more chunks, the higher the number can be.
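As a worked example of the chunk rule, here's a sketch that decodes a literal payload given as a string of '0' and '1' characters. It's a simplified stand-in for the parser we'll write later, and decodeLiteral is a hypothetical name. The sample payload is what follows the 6-bit header in the hexadecimal packet D2FE28, the literal example from the puzzle text.

```haskell
-- Decode the 5-bit groups of a literal payload given as a string of '0'
-- and '1' characters. Returns the value and the number of bits consumed,
-- or Nothing if the input ends before a group starting with '0' is seen.
decodeLiteral :: String -> Maybe (Int, Int)
decodeLiteral = go 0 0
  where
    go value used bits = case splitAt 5 bits of
      (flag : chunk@[_, _, _, _], rest) ->
        let value' = foldl addBit value chunk
            used' = used + 5
        in if flag == '0' then Just (value', used') else go value' used' rest
      _ -> Nothing
    -- Shift the accumulated value left by one bit and add the new bit.
    addBit acc c = 2 * acc + (if c == '1' then 1 else 0)
```

For the payload 101111111000101000, the groups are 10111, 11110 and 00101. Dropping the leading bits leaves 011111100101, which is 2021, and we've consumed 15 bits. The trailing 000 is padding.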

Operator Sizes

Packets that aren't literals are operators. This means they contain a variable number of subpackets.

Operators have one bit (after the 6-bit header) giving a "length" type. A length type of "1" tells us that the following 11 bits give the number of subpackets. If the length bit is "0", then the next 15 bits give the length of all the subpackets in bits.

The Packet Structure

We'll see how these work out as we parse them. But with this structure in mind, one thing we can immediately do is come up with a recursive data type for a packet. I ended up calling this PacketNode since I thought of each as a node in a tree. It's pretty easy to see how to do this. We start with a base constructor for a Literal packet that only stores the version and the packet value. Then we just add an Operator constructor that will have a list of subpackets as well as a field for the operator type.

data PacketNode =
  Literal Word8 Word64 |
  Operator Word8 Word8 [PacketNode]
  deriving (Show)

Once we've parsed the packet, we have two questions to answer. For the easy part, we take the sum of all the packet versions in our packet. For the hard part, we actually calculate the packet value recursively. When we get to that part, we'll see how we use the operators to determine the value.

Solution Approach

The initial "parsing" part of this problem is actually quite easy. But we can observe that even after we have our binary values, it's still a parsing problem! We'll have an easy enough time answering the question once we've parsed our input into a PacketNode. So the core of the problem is parsing the ones and zeros into our PacketNode.

Since this is a parsing problem, we can actually use Megaparsec for the second part, instead of only for getting the input out of the file. Here's a possible signature for our core function:

-- More on this type later
data Bit = Zero | One

parsePacketNode :: (MonadLogger m) => ParsecT Void [Bit] m PacketNode

Whereas we normally use Text as the second type parameter to ParsecT, we can also use any list type, and the library will know what to do! With this function, we'll eventually be able to break our solution into its different parts. But first, we should start with some useful helpers for all our binary parsing.

Binary Utilities

Binary logic comes up fairly often in Advent of Code, and there are quite a few different utilities we would want to use with these ones and zeros. We start with a data type to represent a single bit. For maximum efficiency, we'd want to use a BitVector, but we aren't too worried about that. So we'll make a simple type with two constructors.

data Bit = Zero | One
  deriving (Eq, Ord)

instance Show Bit where
  show Zero = "0"
  show One = "1"

Our first order of business is turning a hexadecimal character into a list of bits. Hexadecimal numbers encapsulate 4 bits. So, for example, 0 should be [Zero, Zero, Zero, Zero], 1 should be [Zero, Zero, Zero, One], and F should be [One, One, One, One]. This is a simple pattern match, but we'll also have a failure case.

parseHexChar :: (MonadLogger m) => Char -> MaybeT m [Bit]
parseHexChar '0' = return [Zero, Zero, Zero, Zero]
parseHexChar '1' = return [Zero, Zero, Zero, One]
parseHexChar '2' = return [Zero, Zero, One, Zero]
parseHexChar '3' = return [Zero, Zero, One, One]
parseHexChar '4' = return [Zero, One, Zero, Zero]
parseHexChar '5' = return [Zero, One, Zero, One]
parseHexChar '6' = return [Zero, One, One, Zero]
parseHexChar '7' = return [Zero, One, One, One]
parseHexChar '8' = return [One, Zero, Zero, Zero]
parseHexChar '9' = return [One, Zero, Zero, One]
parseHexChar 'A' = return [One, Zero, One, Zero]
parseHexChar 'B' = return [One, Zero, One, One]
parseHexChar 'C' = return [One, One, Zero, Zero]
parseHexChar 'D' = return [One, One, Zero, One]
parseHexChar 'E' = return [One, One, One, Zero]
parseHexChar 'F' = return [One, One, One, One]
parseHexChar c = logErrorN ("Invalid Hex Char: " <> pack [c]) >> mzero

If we wanted, we could also include lowercase, but this problem doesn't require it.
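If we did want lowercase support, one compact option is to lean on digitToInt from Data.Char, which accepts both cases. This is only a sketch: it returns a plain Maybe instead of logging in MaybeT, and hexCharToBits is a hypothetical name.

```haskell
import Data.Char (digitToInt, isHexDigit)

data Bit = Zero | One
  deriving (Eq, Show)

-- Expand one hexadecimal digit (either case) into four bits, most
-- significant bit first. Non-hex characters give Nothing.
hexCharToBits :: Char -> Maybe [Bit]
hexCharToBits c
  | isHexDigit c = Just [bitAt p | p <- [8, 4, 2, 1]]
  | otherwise = Nothing
  where
    n = digitToInt c
    -- Test the bit of n whose place value is p.
    bitAt p = if (n `div` p) `mod` 2 == 1 then One else Zero
```

Both 'A' and 'a' produce [One, Zero, One, Zero], while a character like 'G' produces Nothing. The explicit pattern match above is arguably easier to read, so this is mostly a trade of clarity for brevity.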

We also want to be able to turn a list of bits into a decimal number. We'll do this for a couple different sizes of numbers. For smaller numbers (8 bits or below), we might want to return a Word8. For larger numbers we can do Word64. Calculating the decimal number is a tail recursive process, where we track the accumulated sum and the current power of 2.

bitsToDecimal8 :: [Bit] -> Word8
bitsToDecimal8 bits = if length bits > 8
  then error ("Too long! Use bitsToDecimal64! " ++ show bits)
  else btd8 0 1 (reverse bits)
    where
      btd8 :: Word8 -> Word8 -> [Bit] -> Word8
      btd8 accum _ [] = accum
      btd8 accum mult (b : rest) = case b of
        Zero -> btd8 accum (mult * 2) rest
        One -> btd8 (accum + mult) (mult * 2) rest

bitsToDecimal64 :: [Bit] -> Word64
bitsToDecimal64 bits = if length bits > 64
  then error ("Too long! Use bitsToDecimalInteger! " ++ (show $ bits))
  else btd64 0 1 (reverse bits)
    where
      btd64 :: Word64 -> Word64 -> [Bit] -> Word64
      btd64 accum _ [] = accum
      btd64 accum mult (b : rest) = case b of
        Zero -> btd64 accum (mult * 2) rest
        One -> btd64 (accum + mult) (mult * 2) rest
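The error message above points at a bitsToDecimalInteger function that we haven't shown. A possible sketch, written as a left fold over the bits instead of reversing them first, could look like this. Treat it as an assumption about what such a helper might be, not the version in the repository.

```haskell
import Data.List (foldl')

data Bit = Zero | One
  deriving (Eq, Show)

-- Interpret a list of bits (most significant first) as an unbounded
-- Integer. Each step doubles the accumulator and adds the new bit.
bitsToDecimalInteger :: [Bit] -> Integer
bitsToDecimalInteger = foldl' step 0
  where
    step acc Zero = 2 * acc
    step acc One = 2 * acc + 1
```

Because Integer is unbounded, there's no length check. A list of One followed by 64 Zero bits comes out as 2^64, which would overflow a Word64.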

Last of all, we should write a parser for reading a hexadecimal string from our file. This is easy, because Megaparsec already has a parser for a single hexadecimal character.

parseHexadecimal :: (MonadLogger m) => ParsecT Void Text m String
parseHexadecimal = some hexDigitChar

Basic Bit Parsing

With all these utilities in place, we can get started with parsing our list of bits. As mentioned above, we want a function that generally looks like this:

parsePacketNode :: (MonadLogger m) => ParsecT Void [Bit] m PacketNode

However, we need one extra nuance. Because we have one layer that will parse several consecutive packets based on the number of bits parsed, we should also return this number as part of our function. In this way, we'll be able to determine if we're done with the subpackets of an operator packet.

parsePacketNode :: (MonadLogger m) => ParsecT Void [Bit] m (PacketNode, Word64)

We'll also want a wrapper around this function so we can call it from a normal context with the list of bits as the input. This looks a lot like the existing utilities (e.g. for parsing a whole file). We use runParserT from Megaparsec and do a case-branch on the result.

parseBits :: (MonadLogger m) => [Bit] -> MaybeT m PacketNode
parseBits bits = do
  result <- runParserT parsePacketNode "Utils.hs" bits
  case result of
    Left e -> logErrorN ("Failed to parse: " <> (pack . show $ e)) >> mzero
    Right (packet, _) -> return packet

We ignore the "size" of the parsed packet in the primary case, but we'll use its result in the recursive calls to parsePacketNode!

Having done this, we can now start writing basic parser functions. To parse a single bit, we'll just wrap the anySingle combinator from Megaparsec.

parseBit :: ParsecT Void [Bit] m Bit
parseBit = anySingle

If we want to parse a certain number of bits, we'll want to use the monadic count combinator. Let's write a function that parses three bits and turns it into a Word8, since we'll need this for the packet version and type ID.

parse3Bit :: ParsecT Void [Bit] m Word8
parse3Bit = bitsToDecimal8 <$> count 3 parseBit

We can then immediately use this to start filling in our parsing function!

parsePacketNode :: (MonadLogger m) => ParsecT Void [Bit] m (PacketNode, Word64)
parsePacketNode = do
  packetVersion <- parse3Bit
  packetTypeId <- parse3Bit
  ...

Then the rest of the function will depend upon the different cases we might parse.

Parsing a Literal

We can start with the "literal" case. This parses the "value" contained within the packet. We need to track the number of bits we parse so we can use this result in our parent function!

parseLiteral :: ParsecT Void [Bit] m (Word64, Word64)

As explained above, we examine chunks 5 bits at a time, and we end the packet once we have a chunk that starts with 0. This is a "while" loop pattern, which suggests tail recursion as our solution!

We'll have two accumulator arguments. First, the series of bits that contribute to our literal value. Second, the number of bits we've parsed so far (which must include the signal bit).

parseLiteral :: ParsecT Void [Bit] m (Word64, Word64)
parseLiteral = parseLiteralTail [] 0
  where
    parseLiteralTail :: [Bit] -> Word64 -> ParsecT Void [Bit] m (Word64, Word64)
    parseLiteralTail accumBits numBits = do
      ...

First, we'll parse the leading bit, followed by the four bits in the chunk value. We append these to our previously accumulated bits, and add 5 to the number of bits parsed:

parseLiteral :: ParsecT Void [Bit] m (Word64, Word64)
parseLiteral = parseLiteralTail [] 0
  where
    parseLiteralTail :: [Bit] -> Word64 -> ParsecT Void [Bit] m (Word64, Word64)
    parseLiteralTail accumBits numBits = do
      leadingBit <- parseBit
      nextBits <- count 4 parseBit
      let accum' = accumBits ++ nextBits
      let numBits' = numBits + 5
      ...

If the leading bit is 0, we're done! We can return our value by converting our accumulated bits to decimal. Otherwise, we recurse with our new values.

parseLiteral :: ParsecT Void [Bit] m (Word64, Word64)
parseLiteral = parseLiteralTail [] 0
  where
    parseLiteralTail :: [Bit] -> Word64 -> ParsecT Void [Bit] m (Word64, Word64)
    parseLiteralTail accumBits numBits = do
      leadingBit <- parseBit
      nextBits <- count 4 parseBit
      let accum' = accumBits ++ nextBits
      let numBits' = numBits + 5
      if leadingBit == Zero
        then return (bitsToDecimal64 accum', numBits')
        else parseLiteralTail accum' numBits'

Then it's very easy to incorporate this into our primary function. We check the type ID, and if it's "4" (for a literal), we call this function, and return with the Literal packet constructor.

parsePacketNode :: (MonadLogger m) => ParsecT Void [Bit] m (PacketNode, Word64)
parsePacketNode = do
  packetVersion <- parse3Bit
  packetTypeId <- parse3Bit
  if packetTypeId == 4
    then do
      (literalValue, literalBits) <- parseLiteral
      return (Literal packetVersion literalValue, literalBits + 6)
    else
      ...

Now we need to consider the "operator" cases and their subpackets.

Parsing from Number of Packets

We'll start with the simpler of these two cases, which is when we are parsing a specific number of subpackets. The first step, of course, is to parse the length type bit.

parsePacketNode :: (MonadLogger m) => ParsecT Void [Bit] m (PacketNode, Word64)
parsePacketNode = do
  packetVersion <- parse3Bit
  packetTypeId <- parse3Bit
  if packetTypeId == 4
    then do
      (literalValue, literalBits) <- parseLiteral
      return (Literal packetVersion literalValue, literalBits + 6)
    else do
      lengthTypeId <- parseBit
      if lengthTypeId == One
        then do
        ...

First, we have to count out 11 bits and use that to determine the number of subpackets. Once we have this number, we just have to recursively call the parsePacketNode function the given number of times.

parsePacketNode :: (MonadLogger m) => ParsecT Void [Bit] m (PacketNode, Word64)
parsePacketNode = do
  ...
  if packetTypeId == 4
    then ...
    else do
      lengthTypeId <- parseBit
      if lengthTypeId == One
        then do
          numberOfSubpackets <- bitsToDecimal64 <$> count 11 parseBit
          subPacketsWithLengths <- replicateM (fromIntegral numberOfSubpackets) parsePacketNode
          ...

We'll unzip these results to get our list of packets and the lengths. To get our final packet length, we take the sum of the sizes, but we can't forget to add the header bits and the length type bit (7 bits), and the bits from the number of subpackets (11).

parsePacketNode :: (MonadLogger m) => ParsecT Void [Bit] m (PacketNode, Word64)
parsePacketNode = do
  ...
  if packetTypeId == 4
    then ...
    else do
      lengthTypeId <- parseBit
      if lengthTypeId == One
        then do
          numberOfSubpackets <- bitsToDecimal64 <$> count 11 parseBit
          subPacketsWithLengths <- replicateM (fromIntegral numberOfSubpackets) parsePacketNode
          let (subPackets, lengths) = unzip subPacketsWithLengths
          return (Operator packetVersion packetTypeId subPackets, sum lengths + 7 + 11)
        else ...

Parsing from Number of Bits

Parsing based on the number of bits in all the subpackets is a little more complicated, because we have more state to track. As we loop through the different subpackets, we need to keep track of how many bits we still have to parse. So we'll make a separate helper function.

parseForPacketLength :: (MonadLogger m) => Int -> Word64 -> [PacketNode] -> ParsecT Void [Bit] m ([PacketNode], Word64)
parseForPacketLength remainingBits accumBits prevPackets = ...

The base case comes when we have no bits remaining. Ideally, we land on exactly 0 bits. A negative number means a subpacket ran past the declared length, which is a problem. But if it's successful, we'll reverse the accumulated packets and return the number of bits we've accumulated.

parseForPacketLength :: (MonadLogger m) => Int -> Word64 -> [PacketNode] -> ParsecT Void [Bit] m ([PacketNode], Word64)
parseForPacketLength remainingBits accumBits prevPackets = if remainingBits <= 0
  then do
    if remainingBits < 0
      then error "Failing"
      else return (reverse prevPackets, accumBits)
  else ...

In the recursive case, we make one new call to parsePacketNode (the original function, not this helper). This gives us a new packet, and some more bits that we've parsed (this is why we've been tracking that number the whole time). So we can subtract the size from the remaining bits, and add it to the accumulated bits. And then we'll make the actual recursive call to this helper function.

parseForPacketLength :: (MonadLogger m) => Int -> Word64 -> [PacketNode] -> ParsecT Void [Bit] m ([PacketNode], Word64)
parseForPacketLength remainingBits accumBits prevPackets = if remainingBits <= 0
  then do
    if remainingBits < 0
      then error "Failing"
      else return (reverse prevPackets, accumBits)
  else do
    (newPacket, size) <- parsePacketNode
    parseForPacketLength (remainingBits - fromIntegral size) (accumBits + fromIntegral size) (newPacket : prevPackets)

And that's all! All our different pieces fit together now and we're able to parse our packet!
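The final else branch that calls the helper isn't shown above, so here's a rough end-to-end sketch of how the two length types fit together. To keep it runnable with only base, it swaps Megaparsec for a hand-rolled [Bit] -> Maybe (a, [Bit]) parser, uses Int fields in PacketNode, and measures consumed bits by comparing list lengths instead of returning a size. All of the helper names (hexToBits, takeBits, parseCount, parseForLength) are assumptions for illustration.

```haskell
import Data.Char (digitToInt)

data Bit = Zero | One
  deriving (Eq, Show)

-- Assumption: plain Int fields to keep the sketch short.
data PacketNode
  = Literal Int Int
  | Operator Int Int [PacketNode]
  deriving (Eq, Show)

-- Expand a hexadecimal string into bits, most significant bit first.
hexToBits :: String -> [Bit]
hexToBits = concatMap digitBits
  where
    digitBits c = [bitAt (digitToInt c) p | p <- [8, 4, 2, 1]]
    bitAt n p = if (n `div` p) `mod` 2 == 1 then One else Zero

toInt :: [Bit] -> Int
toInt = foldl (\acc b -> 2 * acc + (if b == One then 1 else 0)) 0

-- Take exactly n bits, failing if the input is too short.
takeBits :: Int -> [Bit] -> Maybe ([Bit], [Bit])
takeBits n bits
  | length front == n = Just (front, rest)
  | otherwise = Nothing
  where
    (front, rest) = splitAt n bits

parsePacket :: [Bit] -> Maybe (PacketNode, [Bit])
parsePacket bits0 = do
  (versionBits, bits1) <- takeBits 3 bits0
  (typeBits, bits2) <- takeBits 3 bits1
  let version = toInt versionBits
      typeId = toInt typeBits
  if typeId == 4
    then do
      (value, rest) <- parseLiteral [] bits2
      return (Literal version value, rest)
    else do
      (lengthType, bits3) <- takeBits 1 bits2
      (subPackets, rest) <- if lengthType == [One]
        then do
          (countBits, bits4) <- takeBits 11 bits3
          parseCount (toInt countBits) bits4
        else do
          (lengthBits, bits4) <- takeBits 15 bits3
          parseForLength (toInt lengthBits) [] bits4
      return (Operator version typeId subPackets, rest)

-- Read 5-bit groups until one starts with Zero.
parseLiteral :: [Bit] -> [Bit] -> Maybe (Int, [Bit])
parseLiteral acc bits = do
  (flag : chunk, rest) <- takeBits 5 bits
  let acc' = acc ++ chunk
  if flag == Zero
    then return (toInt acc', rest)
    else parseLiteral acc' rest

-- Length type 1: parse a fixed number of subpackets.
parseCount :: Int -> [Bit] -> Maybe ([PacketNode], [Bit])
parseCount 0 bits = Just ([], bits)
parseCount n bits = do
  (packet, rest) <- parsePacket bits
  (packets, rest') <- parseCount (n - 1) rest
  return (packet : packets, rest')

-- Length type 0: parse subpackets until exactly `remaining` bits are used.
parseForLength :: Int -> [PacketNode] -> [Bit] -> Maybe ([PacketNode], [Bit])
parseForLength remaining acc bits
  | remaining < 0 = Nothing
  | remaining == 0 = Just (reverse acc, bits)
  | otherwise = do
      (packet, rest) <- parsePacket bits
      let used = length bits - length rest
      parseForLength (remaining - used) (packet : acc) rest
```

Running parsePacket on hexToBits "38006F45291200" exercises the 15-bit length path and gives an operator containing the literals 10 and 20. The input "EE00D40C823060" exercises the 11-bit count path and gives three literals: 1, 2 and 3. Both are sample packets from the puzzle text.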

Solving the Problems

Now that we've parsed the packet into our structure, the rest of the problem is actually quite easy and fun! We've created a straightforward recursive structure, and so we can loop through it in a straightforward recursive way. We'll just always use the Literal as the base case, and then loop through the list of packets for the recursive case.

Let's start with summing the packet versions. This will return a Word64 since we could be adding a lot of packet versions. With a Literal packet, we just immediately return the version.

sumPacketVersions :: PacketNode -> Word64
sumPacketVersions (Literal v _) = fromIntegral v
...

Then with operator packets, we just map over the sub-packets, take the sum of their versions, and then add the original packet's version.

sumPacketVersions :: PacketNode -> Word64
sumPacketVersions (Literal v _) = fromIntegral v
sumPacketVersions (Operator v _ packets) = fromIntegral v +
  sum (map sumPacketVersions packets)

Now, for calculating the final packet value, we again start with the Literal case, since we'll just return its value. Note that we'll do this monadically, since we'll have some failure conditions in the later parts.

calculatePacketValue :: MonadLogger m => PacketNode -> MaybeT m Word64
calculatePacketValue (Literal _ x) = return x

Now, for the first time in the problem, we actually have to care what the operators mean! Here's a summary of the first few operators:

0 = Sum of all subpackets
1 = Product of all subpackets
2 = Minimum of all subpackets
3 = Maximum of all subpackets

There are three other operators following the same basic pattern. They expect exactly two subpackets and apply a binary comparison to them. If the comparison is true, the packet value is 1. If it is false, the packet value is 0.

5 = Greater than operator (<)
6 = Less than operator (>)
7 = Equals operator (==)

For the first set of operations, we can recursively calculate the value of the sub-packets, and take the appropriate aggregate function over the list.

calculatePacketValue :: MonadLogger m => PacketNode -> MaybeT m Word64
calculatePacketValue (Literal _ x) = return x
calculatePacketValue (Operator _ 0 packets) = sum <$> mapM calculatePacketValue packets
calculatePacketValue (Operator _ 1 packets) = product <$> mapM calculatePacketValue packets
calculatePacketValue (Operator _ 2 packets) = minimum <$> mapM calculatePacketValue packets
calculatePacketValue (Operator _ 3 packets) = maximum <$> mapM calculatePacketValue packets
...

For the binary operations, we first have to verify that there are exactly two packets.

calculatePacketValue :: MonadLogger m => PacketNode -> MaybeT m Word64
...
calculatePacketValue (Operator _ 5 packets) = do
  if length packets /= 2
    then logErrorN "> operator '5' must have two packets!" >> mzero
    else ...

Then we just de-structure the packets, calculate each value, compare them, and then return the appropriate value.

calculatePacketValue :: MonadLogger m => PacketNode -> MaybeT m Word64
...
calculatePacketValue (Operator _ 5 packets) = do
  if length packets /= 2
    then logErrorN "> operator '5' must have two packets!" >> mzero
    else do
      let [p1, p2] = packets
      v1 <- calculatePacketValue p1
      v2 <- calculatePacketValue p2
      return (if v1 > v2 then 1 else 0)
calculatePacketValue (Operator _ 6 packets) = do
  if length packets /= 2
    then logErrorN "< operator '6' must have two packets!" >> mzero
    else do
      let [p1, p2] = packets
      v1 <- calculatePacketValue p1
      v2 <- calculatePacketValue p2
      return (if v1 < v2 then 1 else 0)
calculatePacketValue (Operator _ 7 packets) = do
  if length packets /= 2
    then logErrorN "== operator '7' must have two packets!" >> mzero
    else do
      let [p1, p2] = packets
      v1 <- calculatePacketValue p1
      v2 <- calculatePacketValue p2
      return (if v1 == v2 then 1 else 0)
calculatePacketValue p = do
  logErrorN ("Invalid packet! " <> (pack . show $ p))
  mzero
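The three comparison cases share the same shape, so they could be folded into one helper. Here's a rough sketch of that idea. To keep it self-contained it swaps the MaybeT/MonadLogger stack for plain Maybe (so there's no logging), and the names binaryOp and evalPacket, as well as the PacketNode field types, are hypothetical.

```haskell
import Data.Word (Word8, Word64)

-- Assumed packet shape, reconstructed from the patterns used in the article.
data PacketNode
  = Literal Word8 Word64
  | Operator Word8 Word8 [PacketNode]

-- Hypothetical helper: require exactly two sub-packets, evaluate both,
-- and turn the comparison result into 1 or 0.
binaryOp :: (Word64 -> Word64 -> Bool) -> [PacketNode] -> Maybe Word64
binaryOp cmp [p1, p2] = do
  v1 <- evalPacket p1
  v2 <- evalPacket p2
  return (if cmp v1 v2 then 1 else 0)
binaryOp _ _ = Nothing

-- Pure stand-in for calculatePacketValue. Like the original, minimum and
-- maximum would still fail on an operator with no sub-packets.
evalPacket :: PacketNode -> Maybe Word64
evalPacket (Literal _ x) = Just x
evalPacket (Operator _ 0 ps) = sum <$> mapM evalPacket ps
evalPacket (Operator _ 1 ps) = product <$> mapM evalPacket ps
evalPacket (Operator _ 2 ps) = minimum <$> mapM evalPacket ps
evalPacket (Operator _ 3 ps) = maximum <$> mapM evalPacket ps
evalPacket (Operator _ 5 ps) = binaryOp (>) ps
evalPacket (Operator _ 6 ps) = binaryOp (<) ps
evalPacket (Operator _ 7 ps) = binaryOp (==) ps
evalPacket _ = Nothing
```

The trade-off is that we lose the operator-specific error messages, so in the logging version the helper would probably want an extra argument for the message.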

Concluding Code

To tie everything together, we just follow the steps.

  1. Parse the hexadecimal from the file
  2. Transform the hexadecimal string into a list of bits
  3. Parse the packet
  4. Answer the question

For the first part, we use sumPacketVersions on the resulting packet.

solveDay16Easy :: String -> IO (Maybe Int)
solveDay16Easy fp = runStdoutLoggingT $ do
  hexLine <- parseFile parseHexadecimal fp
  result <- runMaybeT $ do
    bitLine <- concatMapM parseHexChar hexLine
    packet <- parseBits bitLine
    return $ sumPacketVersions packet
  return (fromIntegral <$> result)

And the "hard" solution is the same, except we use calculatePacketValue instead.

solveDay16Hard :: String -> IO (Maybe Int)
solveDay16Hard fp = runStdoutLoggingT $ do
  hexLine <- parseFile parseHexadecimal fp
  result <- runMaybeT $ do
    bitLine <- concatMapM parseHexChar hexLine
    packet <- parseBits bitLine
    calculatePacketValue packet
  return (fromIntegral <$> result)

And we're done!

Conclusion

That's all for this solution! As always, you can take a look at the code on GitHub. Later this week I'll have the video walkthrough as well. To keep up with all the latest news, make sure to subscribe to our monthly newsletter! Subscribing will give you access to our subscriber resources, like our Beginners Checklist and our Production Checklist.

James Bowen James Bowen

Polymer Expansion

Today we're back with another Advent of Code walkthrough. We're doing the problem from Day 14 of last year. Here are a couple previous walkthroughs:

  1. Day 8 (Seven Segment Display)
  2. Day 11 (Octopus Energy Levels)

If you want to see the code for today, you can find it here on GitHub!

If you're enjoying these problem overviews, make sure to subscribe to our monthly newsletter!

Problem Statement

The subject of today's problem is "polymer expansion". What this means in programming terms is that we'll be taking a string and inserting new characters into it based on side-by-side pairs.

The puzzle input looks like this:

NNCB

NN -> C
NC -> B
CB -> H
...

The top line of the input is our "starter string". It's our base for expansion. The lines that follow are codes that explain how to expand each pair of characters.

So in our original string of four characters (NNCB), there are three pairs: NN, NC, and CB. With the exception of the start and end characters, each character appears in two different pairs. So for each pair, we find the corresponding "insertion character" and construct a new string where all the insertion characters come between their parent pairs. The first pair gives us a C, the second pair gives us a new B, and the third pair gets us a new H.

So our string for the second step becomes: NCNBCHB. We'll then repeat the expansion a certain number of times.

For the first part, we'll run 10 steps of the expansion algorithm. For the second part, we'll do 40 steps. Each time, our final answer comes from taking the number of occurrences of the most common letter in the final string, and subtracting the occurrences of the least common letter.

Utilities

The main utility we'll end up using for this problem is an occurrence map. I decided to generalize the idea of counting the number of occurrences of some item, since it's such a common pattern in these puzzles. The most generic alias we could have is a map where the key and value are parameterized, though the expectation is that i is an Integral type:

type OccMapI a i = Map a i

The most common usage is counting items up from 0. Since the count can never be negative, we would use the unsigned Word type.

type OccMap a = Map a Word

However, for today's problem, we're gonna be dealing with big numbers! So just to be safe, we'll use the unbounded Integer type, and make a separate type definition for that.

type OccMapBig a = Map a Integer

We can make a couple useful helper functions for this occurrence map. First, we can add a certain number value to a key.

addKey :: (Ord a, Integral i) => OccMapI a i -> a -> i -> OccMapI a i
addKey prevMap key count = case M.lookup key prevMap of
    Nothing -> M.insert key count prevMap
    Just x -> M.insert key (x + count) prevMap

We can add a specialization of this for "incrementing" a key, adding 1 to its value. We won't use this for today's solution, but it helps in a lot of cases.

incKey :: (Ord a, Integral i) => OccMapI a i -> a -> OccMapI a i
incKey prevMap key = addKey prevMap key 1
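As a side note, the lookup-then-insert pattern in addKey can also be written with M.insertWith from Data.Map. The sketch below shows that variant along with a quick usage example. The names addKey' and countLetters are made up for illustration.

```haskell
import qualified Data.Map as M

type OccMapI a i = M.Map a i

-- Same behavior as addKey: insert the count if the key is new,
-- otherwise add it to the existing value.
addKey' :: (Ord a, Integral i) => OccMapI a i -> a -> i -> OccMapI a i
addKey' prevMap key count = M.insertWith (+) key count prevMap

-- Count the characters of a string by adding 1 per occurrence.
countLetters :: String -> OccMapI Char Integer
countLetters = foldl (\m c -> addKey' m c 1) M.empty

-- [('B',1),('C',1),('N',2)]
exampleCounts :: [(Char, Integer)]
exampleCounts = M.toList (countLetters "NNCB")
```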

Now with our utilities out of the way, let's start parsing our input!

Parsing the Input

First off, let's define the result types of our parsing process. The starter string comes on the first line, so that's a separate String. But then we need to create a mapping between character pairs and the resulting character. We'll eventually want these in a HashMap, so let's make a type alias for that.

type PairMap = HashMap (Char, Char) Char

Now for parsing, we need to parse the start string, an empty line, and then each line of the code mapping.

Since most of the input is in the code mapping lines, let's do that first. Each line consists of three characters, with an arrow separating the first two from the third. This is very straightforward with Megaparsec.

parsePairCode :: (MonadLogger m) => ParsecT Void Text m (Char, Char, Char)
parsePairCode = do
  input1 <- letterChar
  input2 <- letterChar
  string " -> "
  output <- letterChar
  return (input1, input2, output)

Now let's make a function to combine these character tuples into the map. This is a nice quick fold:

buildPairMap :: [(Char, Char, Char)] -> HashMap (Char, Char) Char
buildPairMap = foldl (\prevMap (c1, c2, c3) -> HM.insert (c1, c2) c3 prevMap) HM.empty

The rest of our parsing function then parses the starter string and a couple newline characters before we get our pair codes.

parseInput :: (MonadLogger m) => ParsecT Void Text m (String, PairMap)
parseInput = do
  starterCode <- some letterChar
  eol >> eol
  pairCodes <- sepEndBy1 parsePairCode eol
  return (starterCode, buildPairMap pairCodes)

Then it will be easy enough to use our parseFile function from previous days. Now let's figure out our solution approach.

A Naive Approach

Now at first, the polymer expansion seems like a fairly simple problem. The root of the issue is that we have to write a function to run one step of the expansion. In principle, this isn't a hard function. We loop through the original string, two letters at a time, and gradually construct the new string for the next step.

One way to handle this would be with a tail recursive helper function. We could accumulate the new string (in reverse) through an accumulator argument.

runExpand :: (MonadLogger m)
  => PairMap
  -> String -- Accumulator
  -> String -- Remaining String
  -> m String

The "base case" of this function is when we have only one character left. In this case, we append it to the accumulator and reverse it all.

runExpand :: (MonadLogger m) => PairMap -> String -> String -> m String
runExpand pairMap accum [lastChar] = return $ reverse (lastChar : accum)

For the recursive case, we imagine we have at least two characters remaining. We'll look these characters up in our map. Then we'll append the first character and the new character to our accumulator, and then recurse on the remainder (including the second character).

runExpand :: (MonadLogger m) => PairMap -> String -> String -> m String
runExpand _ accum [lastChar] = return $ reverse (lastChar : accum)
runExpand pairMap accum (firstChar : secondChar : rest) = do
  let insertChar = pairMap HM.! (firstChar, secondChar)
  runExpand pairMap (insertChar : firstChar : accum) (secondChar : rest)

There are some extra edge cases we could handle here, but this isn't going to be how we solve the problem. The approach works...in theory. In practice though, it only works for a small number of steps. Why? Well the problem description gives a hint: This polymer grows quickly. In fact, with each step, our string essentially doubles in size - exponential growth!

This sort of solution is good enough for the first part, running only 10 steps. However, as the string gets bigger and bigger, we'll run out of memory! So we need something more efficient.
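To put numbers on that growth: if every pair has a rule, a string of k characters becomes 2k - 1 characters after one step, so a starter of length L has (L - 1) * 2^n + 1 characters after n steps. The sketch below checks this against a pure version of the naive step. It uses Data.Map as a stand-in for the HashMap, and expandOnce, lengthAfter and exampleRules are hypothetical names.

```haskell
import qualified Data.Map as M

-- Data.Map stand-in for the article's HashMap-based PairMap.
type PairMap = M.Map (Char, Char) Char

-- One naive expansion step, without logging. Pairs with no rule are
-- left alone here (an assumption, the article's version would fail on them).
expandOnce :: PairMap -> String -> String
expandOnce rules (c1 : c2 : rest) =
  case M.lookup (c1, c2) rules of
    Just new -> c1 : new : expandOnce rules (c2 : rest)
    Nothing -> c1 : expandOnce rules (c2 : rest)
expandOnce _ short = short

-- Length after the given number of steps, assuming every pair has a rule.
lengthAfter :: Integer -> Int -> Integer
lengthAfter len steps = (len - 1) * 2 ^ steps + 1

-- Only the three rules quoted in the problem statement.
exampleRules :: PairMap
exampleRules = M.fromList [(('N', 'N'), 'C'), (('N', 'C'), 'B'), (('C', 'B'), 'H')]

-- "NCNBCHB", which has 7 = lengthAfter 4 1 characters
exampleStep :: String
exampleStep = expandOnce exampleRules "NNCB"
```

For a 4-character starter, lengthAfter 4 10 is 3073, which is no problem at all. But lengthAfter 4 40 is 3298534883329, over 3 trillion characters, which is why the string itself can't be kept around for the second part.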

A Better Approach

The key insight here is that we don't actually care about the order of the letters in the string at any given time. All we really need to think about is the number of each kind of pair that is present. How does this work?

Well recall some of our basic code pairs from the top:

NN -> C
NC -> B
CB -> H
BN -> B

With a starter string like NNCB, we have one NN pair, one NC pair, and one CB pair. In the next step, the NN pair generates two new pairs. Because a C is inserted between the two N's, we lose the NN pair but gain an NC pair and a CN pair. So after expansion, that single NN pair has contributed 1 NC pair and 1 CN pair.

However, this is true of every NN pair within our string! Suppose we instead start off this with:

NNCBNN

Now there are two NN pairs, meaning the resulting string will have two NC pairs and two CN pairs, as you can see by taking a closer look at the result: NCNBCHBBNCN.

So instead of keeping the complete string in memory, all we need to do is use the "occurrence map" utility to store the number of each pair for our current state. So we'll keep folding over an object of type OccMapBig (Char, Char).
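If you want to double check the pair counts in that example, a quick way is to zip the string with its own tail. This is just a throwaway sketch using Data.Map, and pairCounts is a hypothetical helper rather than part of the solution.

```haskell
import qualified Data.Map as M

-- Count every side-by-side pair in a string.
pairCounts :: String -> M.Map (Char, Char) Integer
pairCounts s = M.fromListWith (+) [(p, 1) | p <- zip s (drop 1 s)]

-- Pair counts before and after expanding the NNCBNN example by hand.
beforeCounts, afterCounts :: M.Map (Char, Char) Integer
beforeCounts = pairCounts "NNCBNN"
afterCounts = pairCounts "NCNBCHBBNCN"
```

Looking up ('N', 'N') in beforeCounts gives 2, and looking up ('N', 'C') or ('C', 'N') in afterCounts also gives 2, matching the reasoning above.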

The first step of our solution then is to construct our initial mapping from the starter code. We can do this by folding through the starter string in a similar way to the example code in the naive solution. When one or zero characters are left in our "remainder", that's a base case and we can return the map.

-- Same signature as naive approach
expandPolymerLong :: (MonadLogger m) => Int -> String -> PairMap -> m (Maybe Integer)
expandPolymerLong numSteps starterCode pairMap = do
  let starterMap = buildInitialMap M.empty starterCode
  ...
  where
    buildInitialMap :: OccMapBig (Char, Char) -> String -> OccMapBig (Char, Char)
    buildInitialMap prevMap "" = prevMap
    buildInitialMap prevMap [_] = prevMap
    ...

Now for the recursive case, we have at least two characters remaining, so we'll just increment the value for the key formed by these characters!

-- Same signature as naive approach
expandPolymerLong :: (MonadLogger m) => Int -> String -> PairMap -> m (Maybe Integer)
expandPolymerLong numSteps starterCode pairMap = do
  let starterMap = buildInitialMap M.empty starterCode
  ...
  where
    buildInitialMap :: OccMapBig (Char, Char) -> String -> OccMapBig (Char, Char)
    buildInitialMap prevMap "" = prevMap
    buildInitialMap prevMap [_] = prevMap
    buildInitialMap prevMap (firstChar : secondChar : rest) = buildInitialMap (incKey prevMap (firstChar, secondChar)) (secondChar : rest)

The key point, of course, is how to expand our map each step, so let's do this next!

A New Expansion

To run a single step in our naive solution, we could use a tail-recursive helper to gradually build up the new string (the "accumulator") from the old string (the "remainder" or "rest"). So our type signature looked like this:

runExpand :: (MonadLogger m)
  => PairMap
  -> String -- Accumulator
  -> String -- Remainder
  -> m String

For our new expansion step, we're instead taking one occurrence map and transforming it into a new occurrence map. For convenience, we'll include an integer argument keeping track of which step we're on, but we won't need to use it in the function. We'll do all this within expandPolymerLong so that we have access to the PairMap argument.

expandPolymerLong :: (MonadLogger m) => Int -> String -> PairMap -> m (Maybe Integer)
expandPolymerLong numSteps starterCode pairMap = do
  ...
  where
    runStep :: (MonadLogger m) => OccMapBig (Char, Char) -> Int -> m (OccMapBig (Char, Char))
    runStep = ...

The runStep function has a simple idea behind it though. We gradually reconstruct our occurrence map by folding through the pairs in the previous map. We'll make a new function runExpand to act as the folding function.

expandPolymerLong :: (MonadLogger m) => Int -> String -> PairMap -> m (Maybe Integer)
expandPolymerLong numSteps starterCode pairMap = do
  ...
  where
    runStep :: (MonadLogger m) => OccMapBig (Char, Char) -> Int -> m (OccMapBig (Char, Char))
    runStep prevMap _ = foldM runExpand M.empty (M.toList prevMap)

    runExpand :: (MonadLogger m) => OccMapBig (Char, Char) -> ((Char, Char), Integer) -> m (OccMapBig (Char, Char))
    runExpand = ...

For this function, we begin by looking up the two-character code in our map. If for whatever reason it doesn't exist, we'll move on, but it's worth logging an error message since this isn't supposed to happen.

runExpand :: (MonadLogger m) => OccMapBig (Char, Char) -> ((Char, Char), Integer) -> m (OccMapBig (Char, Char))
runExpand prevMap (code@(c1, c2), count) = case HM.lookup code pairMap of
  Nothing -> logErrorN ("Missing Code: " <> pack [c1, c2]) >> return prevMap
  Just newChar -> ...

Now once we've found the new character, we'll create our first new pair and our second new pair by inserting the new character with our previous characters.

runExpand :: (MonadLogger m) => OccMapBig (Char, Char) -> ((Char, Char), Integer) -> m (OccMapBig (Char, Char))
runExpand prevMap (code@(c1, c2), count) = case HM.lookup code pairMap of
  Nothing -> logErrorN ("Missing Code: " <> pack [c1, c2]) >> return prevMap
  Just newChar -> do
    let first = (c1, newChar)
        second = (newChar, c2)
    ...

And to wrap things up, we add the new count value for each of our new keys to the existing map! This is done with nested calls to addKey on our occurrence map.

runExpand :: (MonadLogger m) => OccMapBig (Char, Char) -> ((Char, Char), Integer) -> m (OccMapBig (Char, Char))
runExpand prevMap (code@(c1, c2), count) = case HM.lookup code pairMap of
  Nothing -> logErrorN ("Missing Code: " <> pack [c1, c2]) >> return prevMap
  Just newChar -> do
    let first = (c1, newChar)
        second = (newChar, c2)
    return $ addKey (addKey prevMap first count) second count
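Here's the same step as a pure sketch, to show that the whole transformation is just "each pair with count n contributes n to two new pairs". It assumes Data.Map in place of the HashMap and drops the logging. The names stepCounts, exampleRules and exampleStart are hypothetical.

```haskell
import qualified Data.Map as M

type PairMap = M.Map (Char, Char) Char
type PairCounts = M.Map (Char, Char) Integer

-- One expansion step over pair counts.
stepCounts :: PairMap -> PairCounts -> PairCounts
stepCounts rules prev = M.fromListWith (+) (concatMap expand (M.toList prev))
  where
    expand ((c1, c2), n) = case M.lookup (c1, c2) rules of
      Just new -> [((c1, new), n), ((new, c2), n)]
      -- Mirrors runExpand: a pair without a rule contributes nothing.
      Nothing -> []

exampleRules :: PairMap
exampleRules = M.fromList [(('N', 'N'), 'C'), (('N', 'C'), 'B'), (('C', 'B'), 'H')]

-- Pair counts for the starter string NNCB.
exampleStart :: PairCounts
exampleStart = M.fromList [(('N', 'N'), 1), (('N', 'C'), 1), (('C', 'B'), 1)]

-- The six pairs of NCNBCHB, each with count 1.
exampleNext :: PairCounts
exampleNext = stepCounts exampleRules exampleStart
```

The work per step depends only on the number of distinct pairs, not on the length of the string, which is what makes 40 steps feasible.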

Rounding Up

Now we have our last task: finding the counts of the characters in the final string, and subtracting the minimum from the maximum. This requires us to first disassemble our mapping of pair counts into a mapping of individual character counts. This is another fold step. But just like before, we use nested calls to addKey on an occurrence map! See how countChars works below:

expandPolymerLong :: (MonadLogger m) => Int -> String -> PairMap -> m (Maybe Integer)
expandPolymerLong numSteps starterCode pairMap = do
  let starterMap = buildInitialMap M.empty starterCode
  finalOccMap <- foldM runStep starterMap [1..numSteps]
  let finalCharCountMap = foldl countChars M.empty (M.toList finalOccMap)
  ...
  where
    countChars :: OccMapBig Char -> ((Char, Char), Integer) -> OccMapBig Char
    countChars prevMap ((c1, c2), count) = addKey (addKey prevMap c1 count) c2 count

So we have a count of the characters in our final string...sort of. Recall that we added characters for each pair. Thus the number we're getting is basically doubled! So we want to divide each value by 2, with the exception of the first and last characters in the string. If these are the same, we have an edge case. We divide the number by 2 and then add an extra one. Otherwise, if a character has an odd value, it must be on the end, so we divide by two and round up. We sum up this logic with the quotRoundUp function, which we apply over our finalCharCountMap.

expandPolymerLong :: (MonadLogger m) => Int -> String -> PairMap -> m (Maybe Integer)
expandPolymerLong numSteps starterCode pairMap = do
  let starterMap = buildInitialMap M.empty starterCode
  finalOccMap <- foldM runStep starterMap [1..numSteps]
  let finalCharCountMap = foldl countChars M.empty (M.toList finalOccMap)
  let finalCounts = map quotRoundUp (M.toList finalCharCountMap)
  ...
  where
    quotRoundUp :: (Char, Integer) -> Integer
    quotRoundUp (c, i) = if even i
      then quot i 2 + (if head starterCode == c && last starterCode == c then 1 else 0)
      else quot i 2 + 1
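The halving logic is easy to get wrong, so here's a small worked check. It relies on the fact that the first and last characters never change during expansion, since new characters are only ever inserted between existing ones. The sketch uses Data.Map and takes the two end characters as arguments. recoverCounts and the example values are hypothetical names.

```haskell
import qualified Data.Map as M

-- Turn pair counts back into character counts, given the first and last
-- characters of the string.
recoverCounts :: Char -> Char -> M.Map (Char, Char) Integer -> M.Map Char Integer
recoverCounts firstC lastC pairs = M.mapWithKey fix doubled
  where
    -- Every character is counted once per pair it belongs to.
    doubled =
      M.fromListWith (+)
        (concat [[(c1, n), (c2, n)] | ((c1, c2), n) <- M.toList pairs])
    fix c i
      | even i = i `quot` 2 + (if c == firstC && c == lastC then 1 else 0)
      | otherwise = i `quot` 2 + 1

-- The six pairs of NCNBCHB. Doubled counts are N=3, C=4, B=3, H=2.
examplePairs :: M.Map (Char, Char) Integer
examplePairs =
  M.fromList
    [ (('N', 'C'), 1), (('C', 'N'), 1), (('N', 'B'), 1)
    , (('B', 'C'), 1), (('C', 'H'), 1), (('H', 'B'), 1)
    ]

-- [('B',2),('C',2),('H',1),('N',2)], the true counts in NCNBCHB
exampleRecovered :: [(Char, Integer)]
exampleRecovered = M.toList (recoverCounts 'N' 'B' examplePairs)

-- Edge case where both ends match: the string NCN has pairs NC and CN.
-- [('C',1),('N',2)]
exampleSameEnds :: [(Char, Integer)]
exampleSameEnds =
  M.toList (recoverCounts 'N' 'N' (M.fromList [(('N', 'C'), 1), (('C', 'N'), 1)]))
```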

And finally, we consider the list of outcomes and take the maximum minus the minimum!

expandPolymerLong :: (MonadLogger m) => Int -> String -> PairMap -> m (Maybe Integer)
expandPolymerLong numSteps starterCode pairMap = do
  let starterMap = buildInitialMap M.empty starterCode
  finalOccMap <- foldM runStep starterMap [1..numSteps]
  let finalCharCountMap = foldl countChars M.empty (M.toList finalOccMap)
  let finalCounts = map quotRoundUp (M.toList finalCharCountMap)
  if null finalCounts
    then logErrorN "Final Occurrence Map is empty!" >> return Nothing
    else return $ Just $ fromIntegral (maximum finalCounts - minimum finalCounts)

  where
    buildInitialMap = ...
    runStep = ...
    runExpand = ...
    countChars = ...
    quotRoundUp = ...

Last of all, we combine input parsing with solving the problem. Our "easy" and "hard" solutions look the same, just with different numbers of steps.

solveDay14Easy :: String -> IO (Maybe Integer)
solveDay14Easy fp = runStdoutLoggingT $ do
  (starterCode, pairCodes) <- parseFile parseInput fp
  expandPolymerLong 10 starterCode pairCodes

solveDay14Hard :: String -> IO (Maybe Integer)
solveDay14Hard fp = runStdoutLoggingT $ do
  (starterCode, pairCodes) <- parseFile parseInput fp
  expandPolymerLong 40 starterCode pairCodes

Conclusion

Hopefully that solution makes sense to you! In case I left anything out of my solution, you can peruse the code on GitHub. Later this week, we'll have a video walkthrough of this solution!

If you're enjoying this content, make sure to subscribe to our monthly newsletter, which will also give you access to our Subscriber Resources!

James Bowen James Bowen

Flashing Octopuses and BFS

Today we continue our new series on Advent of Code solutions from 2021. Last time we solved the seven-segment logic puzzle. Today, we'll look at the Day 11 problem which focuses a bit more on traditional coding structures and algorithms.

This will be another in-depth coding write-up. For the next week or so after this I'll switch to doing video reviews so you can compare the styles. I haven't been too exhaustive with listing imports in these examples though, so if you're curious about those you can take a look at the full solution here on GitHub. So now, let's get started!

Problem Statement

For this problem, we're dealing with a set of octopuses (Advent of Code had an aquatic theme last year). These octopuses apparently have an "energy level" and eventually "flash" when they reach their maximum energy level. They sit nicely in a 2D grid for us, and the puzzle input is just a grid of single-digit integers for their initial "energy level". Here's an example.

5483143223
2745854711
5264556173
6141336146
6357385478
4167524645
2176841721
6882881134
4846848554
5283751526

Now, as time goes by, their energy levels increase. With each step, all energy levels go up by one. So after a single step, the energy grid looks like this:

6594254334
3856965822
6375667284
7252447257
7468496589
5278635756
3287952832
7993992245
5957959665
6394862637

However, when an octopus reaches level 10, it flashes. This has two results for the next step. First, its own energy level always reverts to 0. Second, it increments the energy level of all neighbors as well. This, of course, can make things more complicated, because we can end up with a cascading series of flashes. Even an octopus that has a very low energy level at the start of a step can end up flashing. Here's an example.

Start:
11111
19891
18181
19891
11111

End:
34543
40004
50005
40004
34543

The 1 in the center still ends up flashing. It has four neighbors at 9, which all flash. The surrounding 8's then flash because each has two 9 neighbors. As a result, the 1 has 8 neighbors flashing. Combined with its own increment, it becomes 10, so it also flashes.

The good news is that all flashing octopuses revert to 0. They don't start counting again from other adjacent flashes so we can't get an infinite loop of flashing and we don't have to worry about the "order" of flashing.

For the first part of the problem, we have to find the total number of flashes after a certain number of steps. For the second part, we have to find the first step when all of the octopuses flash.

Solution Approach

There's nothing too difficult about the solution approach here. Incrementing the grid and finding the initial flashes are easy problems. The only tricky part is cascading the flashes. For this, we need a Breadth-First-Search where each item in the queue is a flash to resolve. As long as we're careful in our accounting and in the update step, we should be able to answer the questions fairly easily.

Utilities

As with last time, we'll start the coding portion with a few utilities that will (hopefully) end up being useful for other problems. The first of these is a simple one. We'll use a type synonym Coord2 to represent a 2D integer coordinate.

type Coord2 = (Int, Int)

Next, we'll want another general parsing function. Last time, we covered parseLinesFromFile, which took a general parser and applied it to every line of an input file. But we also might want to incorporate the "line-by-line" behavior into our general parser, so we'll add a function to parse the whole file given a single ParsecT expression. The structure is much the same, it just does even less work than our prior example.

parseFile :: (MonadIO m) => ParsecT Void Text m a -> FilePath -> m a
parseFile parser filepath = do
  input <- pack <$> liftIO (readFile filepath)
  result <- runParserT parser "Utils.hs" input
  case result of
    Left e -> error $ "Failed to parse: " ++ show e
    Right x -> return x

Last of all, this problem deals with 2D grids and spreading out the "effect" of one square over all eight of its neighbors. So let's write a function to get all the adjacent coordinates of a tile. We'll call this getNeighbors8, and it will be very similar to a function getting neighbors in 4 directions that I used in this Dijkstra's algorithm implementation.

getNeighbors8 :: HashMap Coord2 a -> Coord2 -> [Coord2]
getNeighbors8 grid (row, col) = catMaybes
  [maybeUp, maybeUpRight, maybeRight, maybeDownRight, maybeDown, maybeDownLeft, maybeLeft, maybeUpLeft]
  where
    (maxRow, maxCol) = maximum $ HM.keys grid
    maybeUp = if row > 0 then Just (row - 1, col) else Nothing
    maybeUpRight = if row > 0 && col < maxCol then Just (row - 1, col + 1) else Nothing
    maybeRight = if col < maxCol then Just (row, col + 1) else Nothing
    maybeDownRight = if row < maxRow && col < maxCol then Just (row + 1, col + 1) else Nothing
    maybeDown = if row < maxRow then Just (row + 1, col) else Nothing
    maybeDownLeft = if row < maxRow && col > 0 then Just (row + 1, col - 1) else Nothing
    maybeLeft = if col > 0 then Just (row, col - 1) else Nothing
    maybeUpLeft = if row > 0 && col > 0 then Just (row - 1, col - 1) else Nothing

This function could also apply to an Array instead of a Hash Map. In fact, it might be even more appropriate there. But below we'll get into the reasons for using a Hash Map.
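For comparison, the eight cases can also be generated with a list comprehension over row and column offsets. This is only a sketch of an alternative. It takes the maximum row and column as an explicit argument instead of a grid, returns the neighbors in a different order than getNeighbors8, and the name neighbors8In is made up.

```haskell
type Coord2 = (Int, Int)

-- All in-bounds neighbors of a coordinate, given the largest valid
-- (row, column) index of the grid.
neighbors8In :: Coord2 -> Coord2 -> [Coord2]
neighbors8In (maxRow, maxCol) (row, col) =
  [ (r, c)
  | dr <- [-1, 0, 1]
  , dc <- [-1, 0, 1]
  , (dr, dc) /= (0, 0) -- skip the tile itself
  , let r = row + dr
  , let c = col + dc
  , r >= 0, r <= maxRow
  , c >= 0, c <= maxCol
  ]
```

On a 10x10 grid, a corner like (0, 0) gets 3 neighbors, an edge tile like (0, 5) gets 5, and an interior tile like (5, 5) gets all 8.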

Parsing the Input

Now, let's get to the first step of the problem itself, which is to parse the input. In this case, the input is simply a 2D array of single-digit integers, so this is a fairly straightforward process. In fact, I figured this whole function could be re-used as well, so it could also be considered a utility.

The first step is to parse a line of integers. Since there are no spaces and no separators, this is very simple using some.

import Data.Char (digitToInt)
import Text.Megaparsec (some)

parseDigitLine :: ParsecT Void Text m [Int]
parseDigitLine = fmap digitToInt <$> some digitChar

Now getting a repeated set of these "integer lists" over a series of lines uses the same trick we saw last time. We use sepEndBy1 combined with the eol parser for end-of-line.

parse2DDigits :: (Monad m) => ParsecT Void Text m [[Int]]
parse2DDigits = sepEndBy1 parseDigitLine eol

However, we want to go one step further. A list-of-lists-of-ints is a cumbersome data structure. We can't really update it efficiently. Nor, in fact, can we even access 2D indices quickly. There are two good structures for us to use, depending on the problem. We can either use a 2D array, or a HashMap where the keys are 2D coordinates.

Because we'll be updating the structure itself, we want a Hash Map in this case. Haskell's Array structure has no good way to update its values without a full copy. If the structure were read only though, Array would be the better choice. For our current problem, the mutable array pattern would also be an option. But for now I'll keep things simpler.

So we need a function to convert nested integer lists into a Hash Map with coordinates. The first step in this process is to match each list of integers with a row number, and each integer within the list with its column number. Infinite lists, ranges and zip are excellent tools here!

hashMapFromNestedLists :: [[Int]] -> HashMap Coord2 Int
hashMapFromNestedLists inputs = ...
  where
    x = zip [0,1..] (map (zip [0,1..]) inputs)

Now in most languages, we would use a nested for-loop. The outer structure would cover the rows, the inner structure would cover the columns. In Haskell, we'll instead do a 2-level fold. The outer layer (the function f) will cover the rows. The inner layer (function g) will cover the columns. Each step updates the Hash Map appropriately.

hashMapFromNestedLists :: [[Int]] -> HashMap Coord2 Int
hashMapFromNestedLists inputs = foldl f HM.empty x
  where
    x = zip [0,1..] (map (zip [0,1..]) inputs)

    f :: HashMap Coord2 Int -> (Int, [(Int, Int)]) -> HashMap Coord2 Int
    f prevMap (row, pairs) = foldl (g row) prevMap pairs

    g :: Int -> HashMap Coord2 Int -> Coord2 -> HashMap Coord2 Int
    g row prevMap (col, val) = HM.insert (row, col) val prevMap

And now we can pull it all together and parse our input!

solveDay11Easy :: String -> IO (Maybe Int)
solveDay11Easy fp = do
  initialState <- parseFile parse2DDigitHashMap fp
  ...

solveDay11Hard :: String -> IO (Maybe Int)
solveDay11Hard fp = do
  initialState <- parseFile parse2DDigitHashMap fp
  ...

Basic Step Running

Now let's get to the core of the algorithm. The function we really need to get right here is a function to update a single step of the process. This will take our grid as an input and produce the new grid as an output, as well as some extra information. Let's start by making another type synonym for OGrid as the "Octopus grid".

type OGrid = HashMap Coord2 Int

Now a simple version of this function would have a type signature like this:

runStep :: (MonadLogger m) => OGrid -> m OGrid

(As mentioned last time, I'm defaulting to using MonadLogger for most implementation details).

However, we'll include two extra outputs for this function. First, we want an Int for the number of flashes that occurred on this step. This will help us with the first part of the problem, where we are summing the number of flashes given a certain number of steps.

Second, we want a Bool indicating that all of them have flashed. This is easy to derive from the number of flashes and will be our terminal condition flag for the second part of the problem.

runStep :: (MonadLogger m) => OGrid -> m (OGrid, Int, Bool)

Now the first thing we can do while stepping is to increment everything. Once we've done that, it is easy to pick out the coordinates that ought to be our "initial flashes" - all the items where the value is at least 10.

runStep :: (MonadLogger m) => OGrid -> m (OGrid, Int, Bool)
runStep inputGrid = ...
  where
  -- Start by incrementing everything
    incrementedGrid = (+1) <$> inputGrid
    initialFlashes = fst <$> filter (\(_, x) -> x >= 10) (HM.toList incrementedGrid)

Now what do we do with our initial flashes to propagate them? Let's defer this to a helper function, processFlashes. This will be where we perform the BFS step recursively. Using BFS requires a queue and a visited set, so we'll want these as arguments to our processing function. Its result will be the final grid, updated with all the incrementing done by the flashes, as well as the final set of all flashes, including the original ones.

processFlashes :: (MonadLogger m) =>
  HashSet Coord2 -> Seq Coord2 -> OGrid -> m (HashSet Coord2, OGrid)

In calling this from our runStep function, we'll prepopulate the visited set and the queue with the initial group of flashes, as well as passing the "incremented" grid.

runStep :: (MonadLogger m) => OGrid -> m (OGrid, Int, Bool)
runStep inputGrid = do
  (allFlashes, newGrid) <- processFlashes (HS.fromList initialFlashes) (Seq.fromList initialFlashes) incrementedGrid
  ...
  where
  -- Start by incrementing everything
    incrementedGrid = (+1) <$> inputGrid
    initialFlashes = fst <$> filter (\(_, x) -> x >= 10) (HM.toList incrementedGrid)

Now the last thing we need to do is count the total number of flashes and reset all flashed coordinates to 0 before returning. We can also compare the number of flashes to the size of the hash map to see if they all flashed.

runStep :: (MonadLogger m) => OGrid -> m (OGrid, Int, Bool)
runStep inputGrid = do
  (allFlashes, newGrid) <- processFlashes (HS.fromList initialFlashes) (Seq.fromList initialFlashes) incrementedGrid
  let numFlashes = HS.size allFlashes
  let finalGrid = foldl (\g c -> HM.insert c 0 g) newGrid allFlashes
  return (finalGrid, numFlashes, numFlashes == HM.size inputGrid)
  where
  -- Start by incrementing everything
    incrementedGrid = (+1) <$> inputGrid
    initialFlashes = fst <$> filter (\(_, x) -> x >= 10) (HM.toList incrementedGrid)
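Before filling in processFlashes, it can help to see the whole step in one place and check it against the 5x5 example from the problem statement. The following is a condensed, pure sketch and not the article's implementation. It assumes Data.Map and Data.Set instead of the hash-based structures, uses a plain list as the queue instead of Seq, and skips the logging and the "all flashed" flag. The names gridFromRows, neighbors8 and stepGrid are hypothetical.

```haskell
import Data.Char (digitToInt)
import qualified Data.Map.Strict as M
import qualified Data.Set as S

type Coord2 = (Int, Int)
type Grid = M.Map Coord2 Int

-- Build a coordinate-keyed grid from rows of digit characters.
gridFromRows :: [String] -> Grid
gridFromRows rows =
  M.fromList
    [ ((r, c), digitToInt ch)
    | (r, row) <- zip [0 ..] rows
    , (c, ch) <- zip [0 ..] row
    ]

-- The neighbors (up to eight) that actually exist in the grid.
neighbors8 :: Grid -> Coord2 -> [Coord2]
neighbors8 grid (r, c) =
  [ p
  | dr <- [-1, 0, 1]
  , dc <- [-1, 0, 1]
  , (dr, dc) /= (0, 0)
  , let p = (r + dr, c + dc)
  , M.member p grid
  ]

-- One step: increment everything, cascade the flashes breadth-first,
-- then reset every flashed coordinate to 0. Returns the new grid and
-- the number of flashes.
stepGrid :: Grid -> (Grid, Int)
stepGrid input = (M.union resetFlashed cascaded, S.size flashed)
  where
    incremented = M.map (+ 1) input
    initial = M.keys (M.filter (>= 10) incremented)
    (flashed, cascaded) = go (S.fromList initial) initial incremented
    resetFlashed = M.fromSet (const 0) flashed

    go :: S.Set Coord2 -> [Coord2] -> Grid -> (S.Set Coord2, Grid)
    go visited [] grid = (visited, grid)
    go visited (top : rest) grid =
      let ns = neighbors8 grid top
          grid' = foldr (M.adjust (+ 1)) grid ns
          -- Newly flashing neighbors that are not already queued or processed.
          new = [n | n <- ns, grid' M.! n >= 10, not (S.member n visited)]
       in go (foldr S.insert visited new) (rest ++ new) grid'

exampleStart, exampleEnd :: Grid
exampleStart = gridFromRows ["11111", "19891", "18181", "19891", "11111"]
exampleEnd = gridFromRows ["34543", "40004", "50005", "40004", "34543"]
```

Running stepGrid on exampleStart gives exampleEnd together with a flash count of 9: the four 9's, the four 8's, and the 1 in the center.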

Processing Flashes

So now we need to do this flash processing! To re-iterate, this is a BFS problem. We have a queue of coordinates that are flashing. In order to process a single flash, we increment its neighbors and, if incrementing puts a neighbor's energy over 9, add that neighbor to the back of the queue to be processed.

So our inputs are the sequence of coordinates to flash, the current grid, and a set of coordinates we've already visited (since we want to avoid "re-flashing" anything).

processFlashes :: (MonadLogger m) =>
  HashSet Coord2 -> Seq Coord2 -> OGrid -> m (HashSet Coord2, OGrid)

We'll start with a base case. If the queue is empty, we'll return the input grid and the current visited set.

import qualified Data.Sequence as Seq
import qualified Data.HashSet as HS
import qualified Data.HashMap.Strict as HM

processFlashes :: (MonadLogger m) =>
  HashSet Coord2 -> Seq Coord2 -> OGrid -> m (HashSet Coord2, OGrid)
processFlashes visited queue grid = case Seq.viewl queue of
  Seq.EmptyL -> return (visited, grid)
  ...

Now suppose we have a non-empty queue and we can pull off the top element. We'll start by getting all 8 neighboring coordinates in the grid and incrementing their values. There's no harm in re-incrementing coordinates that have flashed already, because we'll just reset every flashed coordinate to 0 at the end of the step.

processFlashes :: (MonadLogger m) =>
  HashSet Coord2 -> Seq Coord2 -> OGrid -> m (HashSet Coord2, OGrid)
processFlashes visited queue grid = case Seq.viewl queue of
  Seq.EmptyL -> return (visited, grid)
  top Seq.:< rest -> do
    -- Get the 8 adjacent coordinates in the 2D grid
    let allNeighbors = getNeighbors8 grid top
        -- Increment the value of all neighbors
        newGrid = foldl (\g c -> HM.insert c ((g HM.! c) + 1) g) grid allNeighbors
        ...

Then we want to filter this neighbors list down to the neighbors we'll add to the queue. So we'll make a predicate shouldAdd that tells us if a neighboring coordinate (a) has an energy level of at least 9 in the grid before this increment (so incrementing it causes a flash) and (b) is not yet visited. This lets us construct our new visited set and the final queue.

processFlashes :: (MonadLogger m) =>
  HashSet Coord2 -> Seq Coord2 -> OGrid -> m (HashSet Coord2, OGrid)
processFlashes visited queue grid = case Seq.viewl queue of
  Seq.EmptyL -> return (visited, grid)
  top Seq.:< rest -> do
    let allNeighbors = getNeighbors8 grid top
        newGrid = foldl (\g c -> HM.insert c ((g HM.! c) + 1) g) grid allNeighbors
        neighborsToAdd = filter shouldAdd allNeighbors
        newVisited = foldl (flip HS.insert) visited neighborsToAdd
        newQueue = foldl (Seq.|>) rest neighborsToAdd
    ...
  where
    shouldAdd :: Coord2 -> Bool
    shouldAdd coord = grid HM.! coord >= 9 && not (HS.member coord visited)

And, the cherry on top, we just have to make our recursive call with the new values.

processFlashes :: (MonadLogger m) =>
  HashSet Coord2 -> Seq Coord2 -> OGrid -> m (HashSet Coord2, OGrid)
processFlashes visited queue grid = case Seq.viewl queue of
  Seq.EmptyL -> return (visited, grid)
  top Seq.:< rest -> do
    let allNeighbors = getNeighbors8 grid top
        newGrid = foldl (\g c -> HM.insert c ((g HM.! c) + 1) g) grid allNeighbors
        neighborsToAdd = filter shouldAdd allNeighbors
        newVisited = foldl (flip HS.insert) visited neighborsToAdd
        newQueue = foldl (Seq.|>) rest neighborsToAdd
    processFlashes newVisited newQueue newGrid
  where
    shouldAdd :: Coord2 -> Bool
    shouldAdd coord = grid HM.! coord >= 9 && not (HS.member coord visited)

With processing done, we have completed our function for running a single step.
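To check the BFS logic in isolation, here is a minimal pure sketch of the same step. It makes several assumptions that are not part of the original code: it uses Data.Map, Data.Set and Data.Sequence from containers instead of the hash-based types, it drops the MonadLogger constraint, and parseGrid, neighbors8, propagate and stepPure are hypothetical names. In particular, neighbors8 is only a guess at what getNeighbors8 does (return the adjacent coordinates that exist in the grid).

```haskell
import Data.Char (digitToInt)
import Data.List (foldl')
import qualified Data.Map.Strict as M
import Data.Sequence (ViewL (..), (|>))
import qualified Data.Sequence as Seq
import qualified Data.Set as S

type Coord = (Int, Int)
type Grid = M.Map Coord Int

-- Hypothetical helper: turn rows of digit characters into a grid.
parseGrid :: [String] -> Grid
parseGrid rows = M.fromList
  [ ((r, c), digitToInt ch)
  | (r, row) <- zip [0 ..] rows
  , (c, ch) <- zip [0 ..] row
  ]

-- Assumed behaviour of getNeighbors8: the (up to) 8 adjacent cells
-- that actually exist in the grid.
neighbors8 :: Grid -> Coord -> [Coord]
neighbors8 grid (r, c) =
  [ n
  | dr <- [-1, 0, 1]
  , dc <- [-1, 0, 1]
  , (dr, dc) /= (0, 0)
  , let n = (r + dr, c + dc)
  , M.member n grid
  ]

-- Same shape as processFlashes: visited set, queue, grid.
propagate :: S.Set Coord -> Seq.Seq Coord -> Grid -> (S.Set Coord, Grid)
propagate visited queue grid = case Seq.viewl queue of
  EmptyL -> (visited, grid)
  top :< rest ->
    let ns = neighbors8 grid top
        -- Increment every neighbor of the flashing cell
        grid' = foldl' (\g n -> M.adjust (+ 1) n g) grid ns
        -- A neighbor flashes if it was at 9 or more before this increment
        toAdd = filter (\n -> grid M.! n >= 9 && not (S.member n visited)) ns
        visited' = foldl' (flip S.insert) visited toAdd
        queue' = foldl' (|>) rest toAdd
     in propagate visited' queue' grid'

-- One full step: increment, propagate flashes, reset flashed cells to 0.
stepPure :: Grid -> (Grid, Int)
stepPure input = (reset, S.size flashed)
  where
    incremented = M.map (+ 1) input
    initial = M.keys (M.filter (>= 10) incremented)
    (flashed, after) = propagate (S.fromList initial) (Seq.fromList initial) incremented
    reset = foldl' (\g n -> M.insert n 0 g) after flashed
```

Running stepPure on the small grid 11111 / 19991 / 19191 / 19991 / 11111 is a handy test: the ring of 9s flashes first, and the center cell is pushed over the threshold by its eight flashing neighbors.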

Easy Solution

Now that we can run a single step, all that's left is to answer the questions! For the first (easy) part, we just want to count the number of flashes that occur over 100 steps. This will follow a basic recursion pattern, where one of the arguments tells us how many steps are left. The stateful values that we're recursing on are the grid itself, which updates each step, and the sum of the number of flashes.

runStepCount :: (MonadLogger m) => Int -> (OGrid, Int) -> m (OGrid, Int)

Let's start with a base case. When we have 0 steps left, we return the inputs as the result.

runStepCount :: (MonadLogger m) => Int -> (OGrid, Int) -> m (OGrid, Int)
runStepCount 0 results = return results
...

The recursive case is also quite easy. We invoke runStep to get the updated grid and the number of flashes, and then recurse with a reduced step count, adding the new flashes to our previous sum.

runStepCount :: (MonadLogger m) => Int -> (OGrid, Int) -> m (OGrid, Int)
runStepCount 0 results = return results
runStepCount i (grid, prevFlashes) = do
  (newGrid, flashCount, _) <- runStep grid
  runStepCount (i - 1) (newGrid, flashCount + prevFlashes)

And then we can call this from our "easy" entrypoint:

solveDay11Easy :: String -> IO (Maybe Int)
solveDay11Easy fp = do
  initialState <- parseFile parse2DDigitHashMap fp
  (_, numFlashes) <- runStdoutLoggingT $ runStepCount 100 (initialState, 0)
  return $ Just numFlashes

Hard Solution

For the second part of the problem, we want to find the first step where all octopuses flash. Obviously once they synchronize the first time, they'll remain synchronized forever after that. So we'll write a slightly different recursive function, this time counting up instead of down.

runTillAllFlash :: (MonadLogger m) => OGrid -> Int -> m Int
runTillAllFlash inputGrid thisStep = ...

Each time we run this function, we'll call runStep. The terminal condition is when the Bool flag we get from runStep becomes true. In this case, we return the current step value.

runTillAllFlash :: (MonadLogger m) => OGrid -> Int -> m Int
runTillAllFlash inputGrid thisStep = do
  (newGrid, _, allFlashed) <- runStep inputGrid
  if allFlashed
    then return thisStep
    ...

Otherwise, we're just going to recurse, except with an incremented step count.

runTillAllFlash :: (MonadLogger m) => OGrid -> Int -> m Int
runTillAllFlash inputGrid thisStep = do
  (newGrid, _, allFlashed) <- runStep inputGrid
  if allFlashed
    then return thisStep
    else runTillAllFlash newGrid (thisStep + 1)

And once again, we wrap up by calling this function from our "hard" entrypoint.

solveDay11Hard :: String -> IO (Maybe Int)
solveDay11Hard fp = do
  initialState <- parseFile parse2DDigitHashMap fp
  firstAllFlash <- runStdoutLoggingT $ runTillAllFlash initialState 1
  return $ Just firstAllFlash

And now we're done! Our program should be able to solve both parts of the problem!
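The two recursion patterns (counting down a fixed number of steps, and counting up until a flag turns true) can be tried out without the grid machinery. The sketch below is only an illustration under stated assumptions: toyStep is a hypothetical stand-in for runStep with no neighbor propagation, and countFlashes and firstAllFlash are made-up pure counterparts of runStepCount and runTillAllFlash.

```haskell
-- Hypothetical stand-in for runStep: every cell gains 1 energy, and any
-- cell that passes 9 flashes and resets to 0 (no neighbor propagation).
toyStep :: [Int] -> ([Int], Int, Bool)
toyStep cells = (map reset bumped, flashes, flashes == length cells)
  where
    bumped = map (+ 1) cells
    flashes = length (filter (> 9) bumped)
    reset x = if x > 9 then 0 else x

-- Counting down: the first argument is the number of steps left.
countFlashes :: Int -> ([Int], Int) -> ([Int], Int)
countFlashes 0 results = results
countFlashes i (cells, prev) =
  let (cells', n, _) = toyStep cells
   in countFlashes (i - 1) (cells', n + prev)

-- Counting up: the second argument is the number of the step being run.
firstAllFlash :: [Int] -> Int -> Int
firstAllFlash cells thisStep =
  let (cells', _, allFlashed) = toyStep cells
   in if allFlashed then thisStep else firstAllFlash cells' (thisStep + 1)
```

With three cells that all start at 7, every cell flashes together on step 3, so firstAllFlash [7, 7, 7] 1 evaluates to 3. Note that firstAllFlash does not terminate if the cells never synchronize, which cannot happen in this toy model unless they start in sync.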

Conclusion

For the next couple articles, I'll be walking through these same problems, except in video format! So stay tuned for that, and make sure you're subscribed to the YouTube channel so you get notifications about it!

And if you're interested in staying up to date with all the latest news on Monday Morning Haskell, make sure to subscribe to our mailing list. This will get you our monthly newsletter, access to our resources page, and you'll also get special offers on all of our video courses!

James Bowen James Bowen

Advent of Code: Seven Segment Logic Puzzle

We're into the last quarter of the year, and this means Advent of Code is coming up again in a couple months! I'm hoping to do a lot of these problems in Haskell again and this time do up-to-date recaps. To prepare for this, I'm going back through my solutions from last year and trying to update them and come up with common helpers and patterns that will be useful this year.

You can follow me doing these implementation reviews on my stream, and you can take a look at my code on GitHub here!

Most of my blog posts for the next few weeks will recap some of these problems. I'll do written summaries of solutions as well as video summaries to see which are more clear. The written summaries will use the In-Depth Coding style, so get ready for a lot of code! As a final note, you'll notice my frequent use of MonadLogger, as I covered in this article. So let's get started!

Problem Statement

I'm going to start with Day 8 from last year, which I found to be an interesting problem because it was more of a logic puzzle than a traditional programming problem. The problem starts with the general concept of a seven segment display, a way of showing numbers on an electronic display (like scoreboards, for example).

We can label each of the seven segments like so, with letters "a" through "g":

Segments a-g:

     aaaa 
    b    c
    b    c
     dddd 
    e    f
    e    f
     gggg

If all seven segments are lit up, this indicates an 8. If only "c" and "f" are lit up, that's 1, and so on.

The puzzle input consists of lines with 10 "code" strings, and 4 "output" strings, separated by a pipe delimiter:

be cfbegad cbdgef fgaecd cgeb fdcge agebfd fecdb fabcd edb | fdgacbe cefdb cefbgd gcbe
edbfga begcd cbg gc gcadebf fbgde acbgfd abcde gfcbed gfec | fcgedb cgb dgebacf gc

The 10 code strings show a "re-wiring" of the seven segment display. On the first line, we see that be is present as a code string. Since only a "one" has length 2, we know that "b" and "e" each refer either to the "c" or "f" segment, since only those segments are lit up for "one". We can use similar lines of logic to fully determine the mapping of code characters to the original segment display.

Once we have this, we can decode each output string on the right side, get a four-digit number, and then add all of these up.

Solution Approach

When I first solved this problem over a year ago, I went through the effort of deriving a general function to decode any string based on the input codes, and then used this function on each output string.

However, upon revisiting the problem, I realized it's quite a bit simpler. The length of the output to decode is obviously the first big branching point (as we'll see, "part 1" of the problem clues you in to this). Four of the numbers have unique lengths of "on" segments:

  1. 2 Segments = 1
  2. 3 Segments = 7
  3. 4 Segments = 4
  4. 7 Segments = 8

Then, three possible numbers have 5 "on" segments (2, 3, and 5). The remaining three (0, 6, 9) use six segments.

However, when it comes to solving these more ambiguous numbers, the key still lies with the digits 1 and 4, because we can always find the codes referring to these by their length. So we can figure out which two code characters are on the right side (referring to the c and f segments) and which two characters refer to "four minus one", so segments b and d. We don't immediately know which is which in either pair, but it doesn't matter!

Between our "length 5" outputs (2, 3, 5), only 3 contains both segments from "one". So if that isn't true, we can then look at the "four minus one" segments (b and d), and if both are present, it's a 5, otherwise it's a 2.

We can employ similar logic for the length-6 possibilities. If either "one" segment is missing, it must be 6. Then if both "four minus one" segments are present, the answer is 9. Otherwise it is 0.

If this logic doesn't make sense in paragraphs, here's a picture that captures the essential branches of the logic.

So how do we turn this solution into code?

Utilities

First, let's start with a couple utility functions. These functions capture patterns that are useful across many different problems. The first of these is countWhere. This is a small helper for when we have a list of items and want the number of items that fulfill a certain predicate. This is a simple matter of filtering on the predicate and taking the length.

countWhere :: (a -> Bool) -> [a] -> Int
countWhere predicate list = length $ filter predicate list

Next we'll have a flexible parsing function. In general, I've been trying to use Megaparsec to parse the problem inputs (though it's often easier to parse them by hand). You can read this series to learn more about parsing in Haskell, and this part specifically for Megaparsec.

But a good general helper we can have is "given a file where each line has a specific format, parse the file into a list of outputs." I refer to this function as parseLinesFromFile.

parseLinesFromFile :: (MonadIO m) => ParsecT Void Text m a -> FilePath -> m [a]
parseLinesFromFile parser filepath = do
  input <- pack <$> liftIO (readFile filepath)
  result <- runParserT (sepEndBy1 parser eol) "Utils.hs" input
  case result of
    Left e -> error $ "Failed to parse: " ++ show e
    Right x -> return x

There are two key observations about this function. First, we take the parser as an input (its type is ParsecT Void Text m a). Second, we apply it line-by-line using the flexible combinator sepEndBy1 and the eol parser for "end of line". The combinator parses one or more instances of the first parser, separated and optionally ended by the second parser. So each instance of the input parser (except perhaps the last) is followed by an end of line: either a newline, or a carriage return followed by a newline.

Parsing the Lines

Now when it comes to the specific problem solution, we always have to start by parsing the input from a file (at least that's how I prefer to do it). The first step of parsing is to determine what we're parsing into. What is the "output type" of parsing the data?

In this case, each line we parse consists of 10 "code" strings and 4 "output" strings. So we can make two types to hold each of these parts - InputCode and OutputCode.

data InputCode = InputCode
  { screen0 :: String
  , screen1 :: String
  , screen2 :: String
  , screen3 :: String
  , screen4 :: String
  , screen5 :: String
  , screen6 :: String
  , screen7 :: String
  , screen8 :: String
  , screen9 :: String
  } deriving (Show)

data OutputCode = OutputCode
  { output1 :: String
  , output2 :: String
  , output3 :: String
  , output4 :: String
  } deriving (Show)

Now each different code string can be captured by the parser some letterChar. If we wanted to be more specific, we could even do something like:

choice [char 'a', char 'b', char 'c', char 'd', char 'e', char 'f', char 'g']

Now for each group of strings, we'll parse them using the same sepEndBy1 combinator we used before. This time, the separator is hspace, covering horizontal space characters (including tabs, but not newlines). Between the two groups, we use string "| " to parse the bar that divides the line. So here's the start of our parser:

parseInputLine :: (MonadLogger m) => ParsecT Void Text m (Maybe (InputCode, OutputCode))
parseInputLine = do
  screenCodes <- sepEndBy1 (some letterChar) hspace
  string "| "
  outputCodes <- sepEndBy1 (some letterChar) hspace
  ...

Both screenCodes and outputCodes are lists, and we want to convert them into our output types. So first, we do some validation and ensure that the right number of strings are in each list. Then we can pattern match and group them properly. Invalid results give Nothing.

parseInputLine :: (MonadLogger m) => ParsecT Void Text m (Maybe (InputCode, OutputCode))
parseInputLine = do
  screenCodes <- sepEndBy1 (some letterChar) hspace
  string "| "
  outputCodes <- sepEndBy1 (some letterChar) hspace
  if length screenCodes /= 10 
    then lift (logErrorN $ "Didn't find 10 screen codes: " <> intercalate ", " (pack <$> screenCodes)) >> return Nothing
    else if length outputCodes /= 4
      then lift (logErrorN $ "Didn't find 4 output codes: " <> intercalate ", " (pack <$> outputCodes)) >> return Nothing
      else
        let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] = screenCodes
            [o1, o2, o3, o4] = outputCodes
        in  return $ Just (InputCode s0 s1 s2 s3 s4 s5 s6 s7 s8 s9, OutputCode o1 o2 o3 o4)
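For comparison, since these inputs are also easy to parse by hand, here is a minimal sketch of the same validation using only base. The name splitLine is hypothetical, and it returns plain lists rather than the InputCode and OutputCode records, so treat it as an illustration of the shape of the data rather than a drop-in replacement.

```haskell
-- Hypothetical by-hand alternative to parseInputLine, using only base.
-- It splits a line on whitespace, finds the "|" token, and checks that
-- there are exactly 10 code strings before it and 4 output strings after it.
splitLine :: String -> Maybe ([String], [String])
splitLine line = case break (== "|") (words line) of
  (codes, "|" : outputs)
    | length codes == 10 && length outputs == 4 -> Just (codes, outputs)
  _ -> Nothing
```

Like the Megaparsec version, a malformed line gives Nothing. Unlike it, this sketch does no logging and does not check that the strings contain only letters.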

Then we can parse the codes using parseLinesFromFile, applying this in both the "easy" part and the "hard" part of the problem.

solveDay8Easy :: String -> IO (Maybe Int)
solveDay8Easy fp = runStdoutLoggingT $ do
  codes <- catMaybes <$> parseLinesFromFile parseInputLine fp
  ...

solveDay8Hard :: String -> IO (Maybe Int)
solveDay8Hard fp = runStdoutLoggingT $ do
  inputCodes <- catMaybes <$> parseLinesFromFile parseInputLine fp
  ...

The First Part

Now to complete the "easy" part of the problem, we have to answer the question: "In the output values, how many times do digits 1, 4, 7, or 8 appear?". As we've discussed, each of these has a unique length. So it's easy to first describe a function that can tell from an output string if it is one of these items:

isUniqueDigitCode :: String -> Bool
isUniqueDigitCode input = length input `elem` [2, 3, 4, 7]

Then we can use our countWhere utility to apply this function and figure out how many of these numbers are in each output code.

uniqueOutputs :: OutputCode -> Int
uniqueOutputs (OutputCode o1 o2 o3 o4) = countWhere isUniqueDigitCode [o1, o2, o3, o4]

Finally, we take the sum of these, applied across the outputs, to get our first answer:

solveDay8Easy :: String -> IO (Maybe Int)
solveDay8Easy fp = runStdoutLoggingT $ do
  codes <- catMaybes <$> parseLinesFromFile parseInputLine fp
  let result = sum $ uniqueOutputs <$> (snd <$> codes)
  return $ Just result
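As a quick worked check, we can apply the same counting to the output halves of the two sample lines from the problem statement. This is a small pure sketch: sampleOutputs and uniqueTotal are hypothetical names, and the outputs are given as plain lists instead of OutputCode values.

```haskell
countWhere :: (a -> Bool) -> [a] -> Int
countWhere predicate = length . filter predicate

isUniqueDigitCode :: String -> Bool
isUniqueDigitCode input = length input `elem` [2, 3, 4, 7]

-- Output halves of the two sample lines shown earlier.
sampleOutputs :: [[String]]
sampleOutputs =
  [ words "fdgacbe cefdb cefbgd gcbe"
  , words "fcgedb cgb dgebacf gc"
  ]

-- Count the 1, 4, 7 and 8 digits on each line, then sum across lines.
uniqueTotal :: [[String]] -> Int
uniqueTotal = sum . map (countWhere isUniqueDigitCode)
```

The first line has two such outputs (lengths 7 and 4) and the second has three (lengths 3, 7 and 2), so uniqueTotal sampleOutputs is 5.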

The Second Part

Now for the hard part! We have to decode each digit in the output, determine its value, and then get the value on the 4-digit display. The root of this is to decode a single string, given the InputCode of 10 values. So let's write a function that does that. We'll use MaybeT since there are some failure conditions on this function.

decodeString :: (MonadLogger m) => InputCode -> String -> MaybeT m Int

As we've discussed, the logic is easy for certain lengths. If the string length is 2, 3, 4 or 7, we have obvious answers.

decodeString :: (MonadLogger m) => InputCode -> String -> MaybeT m Int
decodeString inputCodes output
  | length output == 2 = return 1
  | length output == 3 = return 7
  | length output == 4 = return 4
  | length output == 7 = return 8
  ...

Now for length 5 and 6, we'll have separate functions:

decode5 :: (MonadLogger m) => InputCode -> String -> MaybeT m Int

decode6 :: (MonadLogger m) => InputCode -> String -> MaybeT m Int

Then we can call these from our base function:

decodeString :: (MonadLogger m) => InputCode -> String -> MaybeT m Int
decodeString inputCodes output
  | length output == 2 = return 1
  | length output == 3 = return 7
  | length output == 4 = return 4
  | length output == 7 = return 8
  | length output == 5 = decode5 inputCodes output
  | length output == 6 = decode6 inputCodes output
  | otherwise = mzero

We have a failure case of mzero if the length doesn't fall within our expectations for some reason.

Now before we can write decode5 and decode6, we'll write a helper function. This helper will determine the two characters present in the "one" segment as well as the two characters present in the "four minus one" segment.

For some reason I separated the two Chars for the "one" segment but kept them together for "four minus one". This probably isn't necessary. But anyways, here's our type signature:

sortInputCodes :: (MonadLogger m) => InputCode -> MaybeT m (Char, Char, String)
sortInputCodes ic@(InputCode c0 c1 c2 c3 c4 c5 c6 c7 c8 c9) = do
  ...

Let's start with some more validation. We'll sort the strings by length and ensure the length distributions are correct.

sortInputCodes :: (MonadLogger m) => InputCode -> MaybeT m (Char, Char, String)
sortInputCodes ic@(InputCode c0 c1 c2 c3 c4 c5 c6 c7 c8 c9) = do
  ...
  where
    [sc0, sc1,sc2,sc3,sc4,sc5,sc6,sc7,sc8,sc9] = sortOn length [c0, c1, c2, c3, c4, c5, c6, c7, c8, c9]
    validLengths =
      length sc0 == 2 && length sc1 == 3 && length sc2 == 4 &&
      length sc3 == 5 && length sc4 == 5 && length sc5 == 5 &&
      length sc6 == 6 && length sc7 == 6 && length sc8 == 6 &&
      length sc9 == 7

If the lengths aren't valid, we'll return mzero as a failure case again. But if they are, we'll pattern match to identify our characters for "one" and the string for "four". By deleting the "one" characters, we'll get a string for "four minus one". Then we can return all our items:

sortInputCodes :: (MonadLogger m) => InputCode -> MaybeT m (Char, Char, String)
sortInputCodes ic@(InputCode c0 c1 c2 c3 c4 c5 c6 c7 c8 c9) = do
  if not validLengths
    then logErrorN ("Invalid inputs: " <> (pack . show $ ic)) >> mzero
    else do
      let [sc01, sc02] = sc0
      let fourMinusOne = delete sc02 (delete sc01 sc2)
      return (sc01, sc02, fourMinusOne)
  where
    [sc0, sc1,sc2,sc3,sc4,sc5,sc6,sc7,sc8,sc9] = sortOn length [c0, c1, c2, c3, c4, c5, c6, c7, c8, c9]
    validLengths =
      length sc0 == 2 && length sc1 == 3 && length sc2 == 4 &&
      length sc3 == 5 && length sc4 == 5 && length sc5 == 5 &&
      length sc6 == 6 && length sc7 == 6 && length sc8 == 6 &&
      length sc9 == 7

Length 5 Logic

Now we're ready to decode a string of length 5! We start by sorting the inputs, which gives us the two "one" characters and the "four minus one" string:

decode5 :: (MonadLogger m) => InputCode -> String -> MaybeT m Int
decode5 ic output = do
  (c01, c02, fourMinusOne) <- sortInputCodes ic
  ...

First we'll check whether both "one" characters are present. If they are, we get 3.

decode5 :: (MonadLogger m) => InputCode -> String -> MaybeT m Int
decode5 ic output = do
  (c01, c02, fourMinusOne) <- sortInputCodes ic
  -- If both from c0 are present, it's a 3
  if c01 `elem` output && c02 `elem` output
    then return 3
    else ...

Then if "four minus one" shares both its characters with the output, the answer is 5, otherwise it is 2.

decode5 :: (MonadLogger m) => InputCode -> String -> MaybeT m Int
decode5 ic output = do
  (c01, c02, fourMinusOne) <- sortInputCodes ic
  -- If both from c0 are present, it's a 3
  if c01 `elem` output && c02 `elem` output
    then return 3
    else do
      let shared = fourMinusOne `intersect` output
      if length shared == 2
        then return 5
        else return 2

Length 6 Logic

The logic for length 6 strings is very similar. I wrote it a little differently in this function, but the idea is the same.

decode6 :: (MonadLogger m) => InputCode -> String -> MaybeT m Int
decode6 ic output = do
  (c01, c02, fourMinusOne) <- sortInputCodes ic
  -- If not both from c0 are present, it's a 6
  if not (c01 `elem` output && c02 `elem` output)
    then return 6
    else do
      -- If both of these characters are present in output, 9 else 0
      if all (`elem` output) fourMinusOne then return 9 else return 0

Wrapping Up

Now that we can decode an output string, we just have to do this for all four strings in our output. We multiply each digit by the appropriate power of 10 and add the results.

decodeAllOutputs :: (MonadLogger m) => (InputCode, OutputCode) -> MaybeT m Int
decodeAllOutputs (ic, OutputCode o1 o2 o3 o4) = do
  d01 <- decodeString ic o1
  d02 <- decodeString ic o2
  d03 <- decodeString ic o3
  d04 <- decodeString ic o4
  return $ d01 * 1000 + d02 * 100 + d03 * 10 + d04

And now we can complete our "hard" function by decoding all these inputs and taking their sums.

solveDay8Hard :: String -> IO (Maybe Int)
solveDay8Hard fp = runStdoutLoggingT $ do
  inputCodes <- catMaybes <$> parseLinesFromFile parseInputLine fp
  results <- runMaybeT (mapM decodeAllOutputs inputCodes)
  return $ fmap sum results
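To sanity check the decoding logic against the sample lines without the parsing and logging layers, here is a condensed pure sketch. It is an assumption-laden simplification rather than the code above: it works on plain lists of strings, uses Maybe instead of MaybeT, and oneAndFourMinusOne, decodeDigit and decodeLine are hypothetical names.

```haskell
import Data.List (sortOn, (\\))

-- Hypothetical pure counterpart of sortInputCodes: the code for "one"
-- and the characters of "four minus one". It validates only the three
-- shortest codes, not the full length distribution.
oneAndFourMinusOne :: [String] -> Maybe (String, String)
oneAndFourMinusOne codes = case sortOn length codes of
  (one : _seven : four : _)
    | length one == 2 && length four == 4 -> Just (one, four \\ one)
  _ -> Nothing

-- Decode a single output string using the branching described above.
decodeDigit :: (String, String) -> String -> Maybe Int
decodeDigit (one, fourMinusOne) output = case length output of
  2 -> Just 1
  3 -> Just 7
  4 -> Just 4
  7 -> Just 8
  5 | hasAll one -> Just 3
    | hasAll fourMinusOne -> Just 5
    | otherwise -> Just 2
  6 | not (hasAll one) -> Just 6
    | hasAll fourMinusOne -> Just 9
    | otherwise -> Just 0
  _ -> Nothing
  where
    hasAll = all (`elem` output)

-- Decode all outputs on a line and combine the digits into one number.
decodeLine :: [String] -> [String] -> Maybe Int
decodeLine codes outputs = do
  key <- oneAndFourMinusOne codes
  digits <- mapM (decodeDigit key) outputs
  return (foldl (\acc d -> acc * 10 + d) 0 digits)
```

On the first sample line, "one" is be and "four minus one" is cg, so the outputs fdgacbe cefdb cefbgd gcbe decode to 8, 3, 9 and 4, giving 8394. The second sample line decodes to 9781 in the same way.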

Conclusion

That's all for this week! You can take a look at all this code on GitHub if you want! Here's the main solution module!

Next time, we'll go through another one of these problems! If you'd like to stay up to date with the latest on Monday Morning Haskell, subscribe to our mailing list! This will give you access to all our subscriber resources!

James Bowen James Bowen

Dijkstra Comparison: Looking at the Library Function

In the last few articles I've gone through my approach for generalizing Dijkstra's algorithm in Haskell. The previous parts of this included:

  1. Simple Form of Dijkstra's Algorithm
  2. Generalizing with a Multi-Param Typeclass
  3. Generalizing with a Type Family
  4. A 2D Graph example

But of course I wasn't the first person to think about coming up with a general form of Dijkstra's algorithm in Haskell. Today, we'll look at the API for a library implementation of this algorithm and compare it to the implementations I thought up.

Comparing Types

So let's look at the type signature for the library version of Dijkstra's algorithm.

dijkstra ::
    (Foldable f, Num cost, Ord cost, Ord state)
  => (state -> f state) -- Function to generate list of neighbors
  -> (state -> state -> cost) -- Function for cost generation
  -> (state -> Bool) -- Destination predicate
  -> state -- Initial state
  -> Maybe (cost, [state]) -- Solution cost and path, Nothing if goal is unreachable

We'd like to compare this against the implementations in this series, but it's also useful to think about the second version of this function from the library: dijkstraAssoc. This version combines the functions for generating neighbors and costs:

dijkstraAssoc ::
    (Num cost, Ord cost, Ord state)
  => (state -> [(state, cost)])
  -> (state -> Bool)
  -> state
  -> Maybe (cost, [state])

And now here's the signature for the Multi-Param typeclass version I wrote:

findShortestDistance ::
    (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph)
  => graph -> node -> node -> Distance cost

We can start the comparison by pointing out a few surface-level differences.

First, the library uses Nothing to represent a failure to reach the destination, instead of a Distance type with Infinity. This is a sensible choice to spare API users from incorporating an internal type into their own code. However, it does make some of the internal code more cumbersome.

Second, the library function also includes the full path in its result which is, of course, very helpful most of the time. The implementation details for this aren't too complicated, but it requires tracking an extra structure, so I have omitted it so far in this series.

Third, the library function takes a predicate for its goal, instead of relying on an equality. This helps a lot with situations where you might have many potential destinations.

Functional vs. Object Oriented Design

But the main structural difference between our functions is, of course, the complete lack of a "graph" type in the library implementation! Our version provides the graph object and adds a typeclass constraint so that we can get the neighboring edges. The library version just includes this function as a separate argument to the main function.

Without meaning to, I created an implementation that is more "object oriented". That is, it is oriented around the "object" of the graph. The library implementation is more "functional" in that it relies on passing important information as higher order functions, rather than associating the function specifically with the graph object.

Clearly the library implementation is more in keeping with Haskell's functional nature. Perhaps my mind gravitated towards an object oriented approach because my day job involves C++ and Python.

But the advantages of the functional approach are clear. It's much easier to generalize an algorithm in terms of functions, rather than with objects. By removing the object from the equation entirely, it's one less item that needs to have a parameterized (or templated) type in our final solution.

However, this only works well when functions are first class items that we can pass as input parameters. Languages like C++ and Java have been moving in the direction of making this easier, but the syntax is not nearly as clean as Haskell's.

Partial function application also makes this a lot easier. If we have a function that is written in terms of our graph type, we can still use this with the library function (see the examples below!). It is most convenient if the graph is our first argument, and then we can partially apply the function and get the right input for, say, dijkstraAssoc.

Applying The Library Function

To close this article, let's see these library functions in action with our two basic examples. Recall our original graph type:

import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

newtype Graph = Graph
   { edges :: HashMap String [(String, Int)] }

graph1 :: Graph
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]

To fill in our findShortestDistance function, we can easily use dijkstraAssoc, using the edges field of our object to supply the assoc function.

import Algorithm.Search (dijkstraAssoc)
import Data.Maybe (fromMaybe)

findShortestDistance :: Graph -> String -> String -> Distance Int
findShortestDistance graph start end = case answer of
  Just (dist, path) -> Dist dist
  Nothing -> Infinity
  where
    costFunction node = fromMaybe [] (HM.lookup node (edges graph))
    answer = dijkstraAssoc costFunction (== end) start

...

>> findShortestDistance graph1 "A" "D"
Dist 40

But of course, we can also just skip our in-between function and use the full output of the library function. This is a little more cumbersome in our case, but it still works.

>> let costFunction node = fromMaybe [] (HM.lookup node (edges graph1))
>> dijkstraAssoc costFunction (== "D") "A"
Just (40, ["C", "D"])

Notice that the library function omits the initial state when providing the final path.
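If you want to experiment with this calling convention without installing the library, here is a minimal sketch of a function with the same argument order as dijkstraAssoc. To be clear about the assumptions: searchAssoc is a hypothetical stand-in written for this example and says nothing about how the library is implemented, the graph is stored in a Data.Map instead of a HashMap so that only containers is needed, and edgesFrom and sampleGraph are made-up names.

```haskell
import qualified Data.Map.Strict as M
import qualified Data.Set as S

-- Hypothetical stand-in with the same argument order as dijkstraAssoc.
-- The frontier is a set ordered by cost, so minView pops the cheapest entry.
searchAssoc ::
     (Num cost, Ord cost, Ord state)
  => (state -> [(state, cost)]) -- neighbors with edge costs
  -> (state -> Bool)            -- destination predicate
  -> state                      -- initial state
  -> Maybe (cost, [state])
searchAssoc next isGoal start = go (S.singleton (0, start, [])) S.empty
  where
    go frontier done = case S.minView frontier of
      Nothing -> Nothing
      Just ((cost, node, revPath), rest)
        | isGoal node -> Just (cost, reverse revPath)
        | S.member node done -> go rest done
        | otherwise ->
            let done' = S.insert node done
                new =
                  [ (cost + c, n, n : revPath)
                  | (n, c) <- next node
                  , not (S.member n done')
                  ]
             in go (foldr S.insert rest new) done'

-- A function written in terms of the graph, with the graph as first argument.
edgesFrom :: M.Map String [(String, Int)] -> String -> [(String, Int)]
edgesFrom graph node = M.findWithDefault [] node graph

-- The same edges as graph1.
sampleGraph :: M.Map String [(String, Int)]
sampleGraph = M.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
```

Partially applying edgesFrom sampleGraph gives exactly the kind of function the first argument expects. Like the library output above, this sketch leaves the initial state out of the path, so searchAssoc (edgesFrom sampleGraph) (== "D") "A" gives Just (40, ["C", "D"]).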

Graph 2D Example

Now let's apply it to our 2D graph example as well. This time it's easier to use the original dijkstra function, rather than the version with "assoc" pairs.

import qualified Data.Array as A
import Data.Maybe (catMaybes)
import Algorithm.Search (dijkstra)

newtype Graph2D = Graph2D (A.Array (Int, Int) Int)

getNeighbors :: A.Array (Int, Int) Int -> (Int, Int) -> [(Int, Int)]
getNeighbors input (row, col) = catMaybes [maybeUp, maybeDown, maybeLeft, maybeRight]
  where
    (maxRow, maxCol) = snd . A.bounds $ input
    maybeUp = if row > 0 then Just (row - 1, col) else Nothing
    maybeDown = if row < maxRow then Just (row + 1, col) else Nothing
    maybeLeft = if col > 0 then Just (row, col - 1) else Nothing
    maybeRight = if col < maxCol then Just (row, col + 1) else Nothing

graph2d :: Graph2D
graph2d = Graph2D $ A.listArray ((0, 0), (4, 4))
  [ 0, 2, 1, 3, 2
  , 1, 1, 8, 1, 4
  , 1, 8, 8, 8, 1
  , 1, 9, 9, 9, 1
  , 1, 4, 1, 9, 1
  ]

findShortestPath2D :: Graph2D -> (Int, Int) -> (Int, Int) -> Maybe (Int, [(Int, Int)])
findShortestPath2D (Graph2D graph) start end = dijkstra
  (getNeighbors graph)
  (\_ b -> graph A.! b)
  (== end)
  start

...

>> findShortestPath2D graph2d (0, 0) (4, 4)
Just (14, [(0, 1), (0, 2), (0, 3), (1, 3), (1, 4), (2, 4), (3, 4), (4, 4)])
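We can double check the reported cost by hand, since the cost function charges the value of each cell we move into. The sketch below is a small illustration using only Data.Array; pathCost, isConnected, grid and reportedPath are hypothetical names, not part of the article's code.

```haskell
import qualified Data.Array as A

-- The same values as graph2d, without the newtype wrapper.
grid :: A.Array (Int, Int) Int
grid = A.listArray ((0, 0), (4, 4))
  [ 0, 2, 1, 3, 2
  , 1, 1, 8, 1, 4
  , 1, 8, 8, 8, 1
  , 1, 9, 9, 9, 1
  , 1, 4, 1, 9, 1
  ]

-- The path reported above (it omits the starting cell).
reportedPath :: [(Int, Int)]
reportedPath = [(0, 1), (0, 2), (0, 3), (1, 3), (1, 4), (2, 4), (3, 4), (4, 4)]

-- Total cost: the value of every cell the path enters.
pathCost :: A.Array (Int, Int) Int -> [(Int, Int)] -> Int
pathCost g = sum . map (g A.!)

-- Every move goes exactly one cell up, down, left or right.
isConnected :: (Int, Int) -> [(Int, Int)] -> Bool
isConnected start path = and (zipWith adjacent (start : path) path)
  where
    adjacent (r1, c1) (r2, c2) = abs (r1 - r2) + abs (c1 - c2) == 1
```

Summing the cells gives 2 + 1 + 3 + 1 + 4 + 1 + 1 + 1 = 14, which matches the cost in the result.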

Conclusion

In the final part of this series we'll consider the "monadic" versions of these library functions and why someone would want to use them.

If you enjoyed this article, make sure to subscribe to our mailing list! This will keep you up to date on all the latest news about the site, as well as giving you access to our subscriber resources that will help you on your Haskell journey!

As always, you can find the full code implementation for this article on GitHub. It is also given in the appendix below.

Appendix - Full Code

module DijkstraLib where

import qualified Data.Array as A
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Maybe (catMaybes, fromMaybe)
import Algorithm.Search (dijkstra, dijkstraAssoc)

data Distance a = Dist a | Infinity
  deriving (Show)

newtype Graph = Graph
   { edges :: HashMap String [(String, Int)] }

graph1 :: Graph
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]

findShortestDistance :: Graph -> String -> String -> Distance Int
findShortestDistance graph start end = case answer of
  Just (dist, path) -> Dist dist
  Nothing -> Infinity
  where
    costFunction node = fromMaybe [] (HM.lookup node (edges graph))
    answer = dijkstraAssoc costFunction (== end) start

newtype Graph2D = Graph2D (A.Array (Int, Int) Int)

getNeighbors :: A.Array (Int, Int) Int -> (Int, Int) -> [(Int, Int)]
getNeighbors input (row, col) = catMaybes [maybeUp, maybeDown, maybeLeft, maybeRight]
  where
    (maxRow, maxCol) = snd . A.bounds $ input
    maybeUp = if row > 0 then Just (row - 1, col) else Nothing
    maybeDown = if row < maxRow then Just (row + 1, col) else Nothing
    maybeLeft = if col > 0 then Just (row, col - 1) else Nothing
    maybeRight = if col < maxCol then Just (row, col + 1) else Nothing

graph2d :: Graph2D
graph2d = Graph2D $ A.listArray ((0, 0), (4, 4))
  [ 0, 2, 1, 3, 2
  , 1, 1, 8, 1, 4
  , 1, 8, 8, 8, 1
  , 1, 9, 9, 9, 1
  , 1, 4, 1, 9, 1
  ]

findShortestPath2D :: Graph2D -> (Int, Int) -> (Int, Int) -> Maybe (Int, [(Int, Int)])
findShortestPath2D (Graph2D graph) start end = dijkstra
  (getNeighbors graph)
  (\_ b -> graph A.! b)
  (== end)
  start
James Bowen James Bowen

Dijkstra in a 2D Grid

We've now spent the last few articles looking at implementations of Dijkstra's algorithm in Haskell, with an emphasis on how to generalize the algorithm so it works for different graph types. Here's a quick summary in case you'd like to revisit some of this code, since this article depends on these implementations.

Simple Implementation

Article

GitHub Code

Generalized with a Multi-param Typeclass

Article

GitHub Code

Generalized with a Type Family

Article

GitHub Code

Generalized Dijkstra Example

But now that we have a couple different examples of how we can generalize this algorithm, it's useful to actually see this generalization in action! Recall that our original implementation could only work with this narrowly defined Graph type:

newtype Graph = Graph
   { edges :: HashMap String [(String, Int)] }

We could add type parameters to make this slightly more general, but the structure remains the same.

newtype Graph node cost = Graph
   { edges :: HashMap node [(node, cost)] }

Suppose instead of this kind of explicit graph structure, we had a different kind of graph. Suppose we had a 2D grid of numbers to move through, and the "cost" of moving through each "node" was simply the value at that index in the grid. For example, we could have a grid like this:

[ [0, 2, 1, 3, 2]
  , [1, 1, 8, 1, 4]
  , [1, 8, 8, 8, 1]
  , [1, 9, 9, 9, 1]
  , [1, 4, 1, 9, 1]
  ]

The lowest cost path through this grid uses the following cells, for a total cost of 14:

[ [0, 2, 1, 3, x]
  , [x, x, x, 1, 4]
  , [x, x, x, x, 1]
  , [x, x, x, x, 1]
  , [x, x, x, x, 1]
  ]

We can make a "graph" type out of this grid in Haskell with a newtype wrapper over an Array. The index of our array will be a tuple of 2 integers, indicating row and column.

import qualified Data.Array as A

newtype Graph2D = Graph2D (A.Array (Int, Int) Int)

For simplicity, we'll assume that our array starts at (0, 0).
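
To make the indexing concrete, here is a small sketch of building such an array and reading values back out. The names smallGrid, bottomRight and maxIndex are made up for this example. With (row, col) tuple indices, A.listArray fills the array one row at a time.

```haskell
import qualified Data.Array as A

-- A 2x3 grid; the list is consumed row by row.
smallGrid :: A.Array (Int, Int) Int
smallGrid = A.listArray ((0, 0), (1, 2))
  [ 0, 2, 1
  , 1, 1, 8
  ]

-- Row 1, column 2 holds the last element of the list.
bottomRight :: Int
bottomRight = smallGrid A.! (1, 2)

-- The upper bound gives us the largest valid row and column.
maxIndex :: (Int, Int)
maxIndex = snd (A.bounds smallGrid)
```

Here bottomRight is 8 and maxIndex is (1, 2). The upper bound is what we will compare against when checking whether a neighboring cell is still inside the grid.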

Getting the "Edges"

Because we now have the notion of a DijkstraGraph, all we need to do for this type to make it eligible for our shortest path function is make an instance of the class. The tricky part of this is the function for dijkstraEdges.

We'll start with a more generic function to get the "neighbors" of a cell in a 2D grid. Most cells will have 4 neighbors. But cells along the edge will have 3, and those in the corner will only have 2. We start such a function by defining our type signature and the bounds of the array.

getNeighbors :: A.Array (Int, Int) Int -> (Int, Int) -> [(Int, Int)]
getNeighbors input (row, col) = ...
  where
    (maxRow, maxCol) = snd . A.bounds $ input
    ...

Now we calculate the Maybe for a cell in each direction. We compare against the possible bounds in that direction and return Nothing if it's out of bounds.

getNeighbors :: A.Array (Int, Int) Int -> (Int, Int) -> [(Int, Int)]
getNeighbors input (row, col) = ...
  where
    (maxRow, maxCol) = snd . A.bounds $ input
    maybeUp = if row > 0 then Just (row - 1, col) else Nothing
    maybeDown = if row < maxRow then Just (row + 1, col) else Nothing
    maybeLeft = if col > 0 then Just (row, col - 1) else Nothing
    maybeRight = if col < maxCol then Just (row, col + 1) else Nothing

And as a last step, we use catMaybes to get all the "valid" neighbors.

getNeighbors :: A.Array (Int, Int) Int -> (Int, Int) -> [(Int, Int)]
getNeighbors input (row, col) = catMaybes [maybeUp, maybeDown, maybeLeft, maybeRight]
  where
    (maxRow, maxCol) = snd . A.bounds $ input
    maybeUp = if row > 0 then Just (row - 1, col) else Nothing
    maybeDown = if row < maxRow then Just (row + 1, col) else Nothing
    maybeLeft = if col > 0 then Just (row, col - 1) else Nothing
    maybeRight = if col < maxCol then Just (row, col + 1) else Nothing
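
We can check the "4, 3 or 2 neighbors" claim on a small grid. The block below repeats getNeighbors so that it compiles on its own. The 3x3 grid3 and the three example values are hypothetical names used only for this sketch.

```haskell
import qualified Data.Array as A
import Data.Maybe (catMaybes)

getNeighbors :: A.Array (Int, Int) Int -> (Int, Int) -> [(Int, Int)]
getNeighbors input (row, col) = catMaybes [maybeUp, maybeDown, maybeLeft, maybeRight]
  where
    (maxRow, maxCol) = snd . A.bounds $ input
    maybeUp = if row > 0 then Just (row - 1, col) else Nothing
    maybeDown = if row < maxRow then Just (row + 1, col) else Nothing
    maybeLeft = if col > 0 then Just (row, col - 1) else Nothing
    maybeRight = if col < maxCol then Just (row, col + 1) else Nothing

-- A 3x3 grid of zeros; only the bounds matter for this check.
grid3 :: A.Array (Int, Int) Int
grid3 = A.listArray ((0, 0), (2, 2)) (replicate 9 0)

cornerNeighbors, edgeNeighbors, centerNeighbors :: [(Int, Int)]
cornerNeighbors = getNeighbors grid3 (0, 0) -- [(1,0),(0,1)]
edgeNeighbors = getNeighbors grid3 (0, 1)   -- [(1,1),(0,0),(0,2)]
centerNeighbors = getNeighbors grid3 (1, 1) -- [(0,1),(2,1),(1,0),(1,2)]
```

The results come back in the order up, down, left, right, with the out-of-bounds directions dropped by catMaybes.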

Writing Class Instances

With this function, it becomes very easy to fill in our class instances! Let's start with the Multi-param class. We have to start by specifying the node type, and the cost type. As usual, our cost is a simple Int. But the node in this case is the index of our array - a tuple.

{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}

instance DijkstraGraph Graph2D (Int, Int) Int where
  ...

To complete the instance, we get our neighbors and combine them with the distance, which is calculated strictly from accessing the array at the index.

instance DijkstraGraph Graph2D (Int, Int) Int where
  dijkstraEdges (Graph2D arr) cell = [(n, arr A.! n) | n <- neighbors]
    where
      neighbors = getNeighbors arr cell
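
To see what these edges look like without pulling in the class, here is a standalone sketch. The function gridEdges is a hypothetical name that mirrors the body of dijkstraEdges on a bare array, and getNeighbors is repeated so the block compiles by itself.

```haskell
import qualified Data.Array as A
import Data.Maybe (catMaybes)

getNeighbors :: A.Array (Int, Int) Int -> (Int, Int) -> [(Int, Int)]
getNeighbors input (row, col) = catMaybes [maybeUp, maybeDown, maybeLeft, maybeRight]
  where
    (maxRow, maxCol) = snd . A.bounds $ input
    maybeUp = if row > 0 then Just (row - 1, col) else Nothing
    maybeDown = if row < maxRow then Just (row + 1, col) else Nothing
    maybeLeft = if col > 0 then Just (row, col - 1) else Nothing
    maybeRight = if col < maxCol then Just (row, col + 1) else Nothing

-- Hypothetical standalone version of dijkstraEdges for a bare array.
-- The cost of an edge is the value stored in the cell we move into.
gridEdges :: A.Array (Int, Int) Int -> (Int, Int) -> [((Int, Int), Int)]
gridEdges arr cell = [(n, arr A.! n) | n <- getNeighbors arr cell]

grid :: A.Array (Int, Int) Int
grid = A.listArray ((0, 0), (4, 4))
  [ 0, 2, 1, 3, 2
  , 1, 1, 8, 1, 4
  , 1, 8, 8, 8, 1
  , 1, 9, 9, 9, 1
  , 1, 4, 1, 9, 1
  ]
```

From the top left corner, gridEdges grid (0, 0) gives [((1,0),1),((0,1),2)]. Since the cost is always read from the destination cell, the value in the starting cell never contributes to a path's total.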

Now we can find the shortest distance! As before though, we have to be more explicit with certain types, as inference doesn't seem to work as well with these multi-param typeclasses.

dijkstraInput2D :: Graph2D
dijkstraInput2D = Graph2D $ A.listArray ((0, 0), (4, 4))
  [ 0, 2, 1, 3, 2
  , 1, 1, 8, 1, 4
  , 1, 8, 8, 8, 1
  , 1, 9, 9, 9, 1
  , 1, 4, 1, 9, 1
  ]

-- Dist 14
cost2 :: Distance Int
cost2 = findShortestDistance dijkstraInput2D (0 :: Int, 0 :: Int) (4, 4)

Type Family Instance

Filling in the type family version is essentially the same. All that's different is listing the node and cost types inside the definition instead of using separate parameters.

{-# LANGUAGE TypeFamilies #-}

instance DijkstraGraph Graph2D where
  type DijkstraNode Graph2D = (Int, Int)
  type DijkstraCost Graph2D = Int
  dijkstraEdges (Graph2D arr) cell = [(n, arr A.! n) | n <- neighbors]
    where
      neighbors = getNeighbors arr cell

And calling our shortest path function works here as well, this time without needing extra type specifications.

dijkstraInput2D :: Graph2D
dijkstraInput2D = Graph2D $ A.listArray ((0, 0), (4, 4))
  [ 0, 2, 1, 3, 2
  , 1, 1, 8, 1, 4
  , 1, 8, 8, 8, 1
  , 1, 9, 9, 9, 1
  , 1, 4, 1, 9, 1
  ]

-- Dist 14
cost3 :: Distance Int
cost3 = findShortestDistance dijkstraInput2D (0, 0) (4, 4)

Conclusion

Next time, we'll look at an even more complicated example for this problem. In the meantime, make sure you subscribe to our mailing list so you can stay up to date with the latest news!

As usual, the full code is in the appendix below. Note though that it depends on code from our previous parts: Dijkstra 2 (the Multi-param typeclass implementation) and Dijkstra 3, the version with a type family.

For the next part of this series, we'll compare this implementation with an existing library function!

Appendix

You can also find this code right here on GitHub.

{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}

module Graph2D where

import qualified Data.Array as A
import Data.Maybe (catMaybes)

import qualified Dijkstra2 as D2
import qualified Dijkstra3 as D3

newtype Graph2D = Graph2D (A.Array (Int, Int) Int)

instance D2.DijkstraGraph Graph2D (Int, Int) Int where
  dijkstraEdges (Graph2D arr) cell = [(n, arr A.! n) | n <- neighbors]
    where
      neighbors = getNeighbors arr cell

instance D3.DijkstraGraph Graph2D where
  type DijkstraNode Graph2D = (Int, Int)
  type DijkstraCost Graph2D = Int
  dijkstraEdges (Graph2D arr) cell = [(n, arr A.! n) | n <- neighbors]
    where
      neighbors = getNeighbors arr cell

getNeighbors :: A.Array (Int, Int) Int -> (Int, Int) -> [(Int, Int)]
getNeighbors input (row, col) = catMaybes [maybeUp, maybeDown, maybeLeft, maybeRight]
  where
    (maxRow, maxCol) = snd . A.bounds $ input
    maybeUp = if row > 0 then Just (row - 1, col) else Nothing
    maybeDown = if row < maxRow then Just (row + 1, col) else Nothing
    maybeLeft = if col > 0 then Just (row, col - 1) else Nothing
    maybeRight = if col < maxCol then Just (row, col + 1) else Nothing

dijkstraInput2D :: Graph2D
dijkstraInput2D = Graph2D $ A.listArray ((0, 0), (4, 4))
  [ 0, 2, 1, 3, 2
  , 1, 1, 8, 1, 4
  , 1, 8, 8, 8, 1
  , 1, 9, 9, 9, 1
  , 1, 4, 1, 9, 1
  ]

cost2 :: D2.Distance Int
cost2 = D2.findShortestDistance dijkstraInput2D (0 :: Int, 0 :: Int) (4, 4)

cost3 :: D3.Distance Int
cost3 = D3.findShortestDistance dijkstraInput2D (0, 0) (4, 4)
James Bowen James Bowen

Dijkstra with Type Families

In the previous part of this series, I wrote about a more general form of Dijkstra’s Algorithm in Haskell. This implementation used a Multi-param Typeclass to encapsulate the behavior of a “graph” type in terms of its node type and the cost/distance type. This is a perfectly valid implementation, but I wanted to go one step further and try out a different approach to generalization.

Type Holes

In effect, I view this algorithm as having three “type holes”. In order to create a generalized algorithm, we want to allow the user to specify their own “graph” type. But we need to be able to refer to two related types (the node and the cost) in order to make our algorithm work. So we can regard each of these as a “hole” in our algorithm.

The multi-param typeclass fills all three of these holes at the “top level” of the instance definition.

class DijkstraGraph graph node cost where
    ...

But there’s a different way to approach this pattern of “one general class and two related types”. This is to use a type family.

Type Families

A type family is an extension of a typeclass. While a typeclass allows you to specify functions and expressions related to the “target” type of the class, a type family allows you to associate other types with the target type. We start our definition the same way we would with a normal typeclass, except we’ll need a special compiler extension:

{-# LANGUAGE TypeFamilies #-}

class DijkstraGraph graph where
  ...

So we’re only specifying graph as the “target type” of this class. We can then specify names for different types associated with this class using the type keyword.

class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  ...

For each type, instead of specifying a type signature after the colons (::), we specify the kind of that type. We just want base-level types for these, so they are *. (As a different example, a monad type like IO would have the kind * -> * because it takes a type parameter).
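
If associated types are new to you, a toy example unrelated to graphs may help. The Container class, Elem and firstElem below are invented for this sketch. It also spells the kind as Type (imported from Data.Kind), which as far as I know is the name newer GHC versions prefer for the same kind we write as * in this article.

```haskell
{-# LANGUAGE TypeFamilies #-}

import Data.Kind (Type)

-- A toy class with a single associated type.
class Container c where
  type Elem c :: Type -- a plain, base-level type (the kind written as "*")
  firstElem :: c -> Maybe (Elem c)

-- For lists, the associated element type is the list's type parameter.
instance Container [a] where
  type Elem [a] = a
  firstElem [] = Nothing
  firstElem (x : _) = Just x
```

With this instance, firstElem [3, 4 :: Int] has type Maybe Int, because the compiler resolves Elem [Int] to Int.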

Once we’ve specified these type names, we can then use them within type signatures for functions in the class. So the last piece we need is the edges function, which we can write in terms of our types.

class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

Making an Instance of the Type Family

It’s not hard now to make an instance for this class, using our existing Graph String Int concept. We again use the type keyword to specify that we are filling in the type holes, and define the edges function as we have before:

{-# LANGUAGE FlexibleInstances #-}

instance DijkstraGraph (Graph String Int) where
    type DijkstraNode (Graph String Int) = String
    type DijkstraCost (Graph String Int) = Int
    dijkstraEdges graph node = fromMaybe [] (HM.lookup node (edges graph))

This hasn’t fixed the “re-statement of parameters” issue I mentioned last time. In fact we now restate them twice. But with a more general type, we wouldn’t necessarily have to parameterize our Graph type anymore.
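
To illustrate that last point, here is a sketch of an instance for a graph type with no parameters at all. The NameGraph type and nameGraph1 value are hypothetical. I've also used Data.Map in place of HashMap so that the block only needs the containers package, and written the kind as Type (from Data.Kind) instead of *.

```haskell
{-# LANGUAGE TypeFamilies #-}

import Data.Kind (Type)
import qualified Data.Map as M
import Data.Maybe (fromMaybe)

class DijkstraGraph graph where
  type DijkstraNode graph :: Type
  type DijkstraCost graph :: Type
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

-- Hypothetical graph type with the node and cost types fixed inside it.
newtype NameGraph = NameGraph (M.Map String [(String, Int)])

-- The node and cost types are each stated once, in the instance body.
instance DijkstraGraph NameGraph where
  type DijkstraNode NameGraph = String
  type DijkstraCost NameGraph = Int
  dijkstraEdges (NameGraph m) node = fromMaybe [] (M.lookup node m)

nameGraph1 :: NameGraph
nameGraph1 = NameGraph $ M.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
```

Because the instance head is just NameGraph, this version doesn't need FlexibleInstances, and the only place String and Int appear is on the right-hand side of the two type lines.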

Updating the Type Signature

The last thing to do now is update the type signatures within our function. Instead of using separate graph, node, and cost parameters, we’ll just use one parameter g for the graph, and define everything in terms of g, DijkstraNode g, and DijkstraCost g.

First, let’s remind ourselves what this looked like for the multi-param typeclass version:

findShortestDistance ::
  forall graph node cost. (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost

And now with the type family version:

findShortestDistance ::
  forall g. (Hashable (DijkstraNode g), Eq (DijkstraNode g), Num (DijkstraCost g), Ord (DijkstraCost g), DijkstraGraph g) =>
  g -> DijkstraNode g -> DijkstraNode g -> Distance (DijkstraCost g)

The “simpler” typeclass ends up being shorter. But the type family version actually has fewer type arguments (just g instead of graph, node, and cost). It’s up to you which you prefer.

And don’t forget, type signatures within the function will also need to change:

processQueue ::
  DijkstraState (DijkstraNode g) (DijkstraCost g) ->
  HashMap (DijkstraNode g) (Distance (DijkstraCost g))

Aside from that, the rest of this function works!

graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]

...

>> findShortestDistance graph1 "A" "D"
Dist 40

Unlike the multi-param typeclass version, this one has no need of specifying the final result type in the expression. Type inference seems to work better here.

Conclusion

Is this version better than the multi-param typeclass version? Maybe, maybe not, depending on your taste. It has some definite weaknesses in that type families are a more foreign concept to most Haskellers, and they require more language extensions. The type signatures are also a bit more cumbersome. But, in my opinion, the semantics are more correct by making the graph type the primary target and the node and cost types as simply “associated” types.

In the next part of this series, we’ll apply these different approaches to some different graph types. This will demonstrate that the approach is truly general and can be used for many different problems!

Below you can find the full code, or you can follow this link to see everything on GitHub!

Appendix - Full Code

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Dijkstra3 where

import Data.Hashable (Hashable)
import qualified Data.Heap as H
import Data.Heap (MinPrioHeap)
import qualified Data.HashSet as HS
import Data.HashSet (HashSet)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist x = False
  Dist x <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

(!??) :: (Hashable k, Eq k) => HashMap k (Distance d) -> k -> Distance d
(!??) distanceMap key = fromMaybe Infinity (HM.lookup key distanceMap)

newtype Graph node cost = Graph
   { edges :: HashMap node [(node, cost)] }

class DijkstraGraph graph where
  type DijkstraNode graph :: *
  type DijkstraCost graph :: *
  dijkstraEdges :: graph -> DijkstraNode graph -> [(DijkstraNode graph, DijkstraCost graph)]

instance DijkstraGraph (Graph String Int) where
    type DijkstraNode (Graph String Int) = String
    type DijkstraCost (Graph String Int) = Int
    dijkstraEdges graph node = fromMaybe [] (HM.lookup node (edges graph))

data DijkstraState node cost = DijkstraState
  { visitedSet :: HashSet node
  , distanceMap :: HashMap node (Distance cost)
  , nodeQueue :: MinPrioHeap (Distance cost) node
  }

findShortestDistance :: forall g. (Hashable (DijkstraNode g), Eq (DijkstraNode g), Num (DijkstraCost g), Ord (DijkstraCost g), DijkstraGraph g) => g -> DijkstraNode g -> DijkstraNode g -> Distance (DijkstraCost g)
findShortestDistance graph src dest = processQueue initialState !?? dest
  where
    initialVisited = HS.empty
    initialDistances = HM.singleton src (Dist 0)
    initialQueue = H.fromList [(Dist 0, src)]
    initialState = DijkstraState initialVisited initialDistances initialQueue

    processQueue :: DijkstraState (DijkstraNode g) (DijkstraCost g) -> HashMap (DijkstraNode g) (Distance (DijkstraCost g))
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else
          -- Update the visited set
          let v1 = HS.insert node v0
          -- Get all unvisited neighbors of our current node
              allNeighbors = dijkstraEdges graph node
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors
          -- Fold each neighbor and recursively process the queue
          in  processQueue $ foldl (foldNeighbor node) (DijkstraState v1 d0 q1) unvisitedNeighbors
    foldNeighbor current ds@(DijkstraState v1 d0 q1) (neighborNode, cost) =
      let altDistance = addDist (d0 !?? current) (Dist cost)
      in  if altDistance < d0 !?? neighborNode
            then DijkstraState v1 (HM.insert neighborNode altDistance d0) (H.insert (altDistance, neighborNode) q1)
            else ds

graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
James Bowen James Bowen

Generalizing Dijkstra's Algorithm

Earlier this week, I wrote a simplified implementation of Dijkstra’s algorithm. You can read the article for details or just look at the full code implementation here on GitHub. This implementation is fine to use within a particular project for a particular purpose, but it doesn’t generalize very well. Today we’ll explore how to make this idea more general.

I chose to do this without looking at any existing implementations of Dijkstra’s algorithm in Haskell libraries to see how my approach would be different. So at the end of this series I’ll also spend some time comparing my approach to some other ideas that exist in the Haskell world.

Parameterizing the Graph Type

So why doesn’t this approach generalize? Well, for the obvious reason that my module defines a specific type for the graph:

newtype Graph = Graph
   { edges :: HashMap String [(String, Int)] }

So, for someone else to re-use this code from the perspective of a different project, they would have to take whatever graph information they had, and turn it into this specific type. And their data might not map very well into String values for the nodes, and they might also have a different cost type in mind than the simple Int value, with Double being the most obvious example.

So we could parameterize the graph type and allow more customization of the underlying values.

newtype Graph node cost = Graph
   { edges :: HashMap node [(node, cost)] }

The function signature would have to change to reflect this, and we would have to impose some additional constraints on these types:

findShortestDistance :: (Hashable node, Eq node, Num cost, Ord cost) =>
  Graph node cost  -> node -> node -> Distance cost

Graph Enumeration

But this would still leave us with an important problem. Sometimes you don’t want to have to enumerate the whole graph. As is, the expression you submit as the "graph" to the function must have every edge enumerated, or it won’t give you the right answer. But many times, you won’t want to list every edge because they are so numerous. Rather, you want to be able to list every edge simply from a particular node. For example:

edgesForNode :: Graph node cost -> node -> [(node, cost)]
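
As a sketch of why this matters, consider a graph that is never written out at all. In the hypothetical StepGraph below, the nodes are integers and the edges out of any node are computed on demand. Both StepGraph and stepEdges are invented for this example.

```haskell
-- Hypothetical implicit graph. From a node n we can step to n + 1 (cost 1)
-- or to 2 * n (cost 2), as long as we stay at or below the stored limit.
newtype StepGraph = StepGraph Int

-- The edges leaving one node, computed without any stored edge list.
stepEdges :: StepGraph -> Int -> [(Int, Int)]
stepEdges (StepGraph limit) n =
  [ (next, cost)
  | (next, cost) <- [(n + 1, 1), (2 * n, 2)]
  , next <= limit
  ]
```

For example, stepEdges (StepGraph 10) 3 gives [(4,1),(6,2)]. Listing every edge of this graph up front would be wasteful for a large limit, but a shortest path search only ever needs the edges of the nodes it actually visits.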

How can we capture this behavior more generally in Haskell?

Using a Typeclass

Well one of the tools we typically turn to for this task is a typeclass. We might want to define something like this:

class DijkstraGraph graph where
  dijkstraEdges :: graph node cost -> node -> [(node, cost)]

However, it quickly gets a bit strange to try to do this with a simple typeclass because of the node and cost parameters. It’s difficult to resolve the constraints we end up needing because these parameters aren’t really part of our class.

Using a Multi-Param Typeclass

We could instead try having a multi-param typeclass like this:

{-# LANGUAGE MultiParamTypeClasses #-}

class DijkstraGraph graph node cost where
  dijkstraEdges :: graph -> node -> [(node, cost)]

This actually works more smoothly than the previous try. We can construct an instance (if we allow flexible instances).

{-# LANGUAGE FlexibleInstances #-}

import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)

instance DijkstraGraph (Graph String Int) String Int where
  dijkstraEdges g n = fromMaybe [] (HM.lookup n (edges g))

And we can actually use this class in our function now! It mainly requires changing around a few of our type signatures. We can start with our DijkstraState type, which must now be parameterized by the node and cost:

data DijkstraState node cost = DijkstraState
  { visitedSet :: HashSet node
  , distanceMap :: HashMap node (Distance cost)
  , nodeQueue :: MinPrioHeap (Distance cost) node
  }

And, of course, we would also like to generalize the type signature of our findShortestDistance function. In its simplest form, we would like to use this:

findShortestDistance :: graph -> node -> node -> Distance cost

However, a couple extra items are necessary to make this work. First, as above, our function is the correct place to assign constraints to the node and cost types. The node type must fit into our hashing structures, so it should fulfill Eq and Hashable. The cost type must be Ord and Num in order for us to perform our addition operations and use it for the heap. And last of course, we have to add the constraint regarding the DijkstraGraph itself:

findShortestDistance ::
  (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost

Now, if we want to use the graph, node, and cost types within the “inner” type signatures of our function, we need one more thing. We need a forall specifier on the function so that the compiler knows we are referring to the same types.

{-# LANGUAGE ScopedTypeVariables #-}

findShortestDistance :: forall graph node cost.
  (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) =>
  graph -> node -> node -> Distance cost
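
Here is a much smaller function that shows the same effect in isolation. The name pairWithFirst is made up for this sketch. The inner helper mentions the outer type variable a in its own signature, which only works because of the forall combined with ScopedTypeVariables.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

-- With the explicit forall, the "a" in the inner signature is the same "a"
-- as in the outer one. Without it, GHC treats the inner "a" as a fresh type
-- variable and rejects the definition, since x has the outer type.
pairWithFirst :: forall a. [a] -> [(a, a)]
pairWithFirst [] = []
pairWithFirst xs@(x : _) = map attach xs
  where
    attach :: a -> (a, a)
    attach y = (x, y)
```

Our processQueue helper is in the same position as attach here: its signature needs to talk about the very same node and cost types as the enclosing function.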

We can now make one change to our function so that it works with our class.

processQueue :: DijkstraState node cost -> HashMap node (Distance cost)
processQueue = ...
  -- Previously
  -- allNeighbors = fromMaybe [] (HM.lookup node (edges graph))
  -- Updated
  allNeighbors = dijkstraEdges graph node

And now we’re done! We can again verify the behavior. However, we do run into one difficulty: we need some extra type specifiers to help the compiler figure everything out.

graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]

...

>> :set -XFlexibleContexts
>> findShortestDistance graph1 "A" "D" :: Distance Int
Dist 40

Conclusion

Below in the appendix is the full code for this part. You can also take a look at it on GitHub here.

For various reasons, I don’t love this attempt at generalizing. I especially don't like the "re-statement" of the parameter types in the instance. The parameters are part of the Graph type and are separately parameters of the class. This is what forces us to specify the Distance Int type in the GHCi session above. We could avoid this if we don't parameterize our Graph type, which is definitely an option.

In the next part of this series, we'll make a second attempt at generalizing this algorithm!

Appendix

{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Dijkstra2 where

import Data.Hashable (Hashable)
import qualified Data.Heap as H
import Data.Heap (MinPrioHeap)
import qualified Data.HashSet as HS
import Data.HashSet (HashSet)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist x = False
  Dist x <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

(!??) :: (Hashable k, Eq k) => HashMap k (Distance d) -> k -> Distance d
(!??) distanceMap key = fromMaybe Infinity (HM.lookup key distanceMap)

newtype Graph node cost = Graph
   { edges :: HashMap node [(node, cost)] }

class DijkstraGraph graph node cost where
    dijkstraEdges :: graph -> node -> [(node, cost)]

instance DijkstraGraph (Graph String Int) String Int where
    dijkstraEdges g n = fromMaybe [] (HM.lookup n (edges g))

data DijkstraState node cost = DijkstraState
  { visitedSet :: HashSet node
  , distanceMap :: HashMap node (Distance cost)
  , nodeQueue :: MinPrioHeap (Distance cost) node
  }

findShortestDistance :: forall graph node cost. (Hashable node, Eq node, Num cost, Ord cost, DijkstraGraph graph node cost) => graph -> node -> node -> Distance cost
findShortestDistance graph src dest = processQueue initialState !?? dest
  where
    initialVisited = HS.empty
    initialDistances = HM.singleton src (Dist 0)
    initialQueue = H.fromList [(Dist 0, src)]
    initialState = DijkstraState initialVisited initialDistances initialQueue

    processQueue :: DijkstraState node cost -> HashMap node (Distance cost)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else
          -- Update the visited set
          let v1 = HS.insert node v0
          -- Get all unvisited neighbors of our current node
              allNeighbors = dijkstraEdges graph node
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors
          -- Fold each neighbor and recursively process the queue
          in  processQueue $ foldl (foldNeighbor node) (DijkstraState v1 d0 q1) unvisitedNeighbors
    foldNeighbor current ds@(DijkstraState v1 d0 q1) (neighborNode, cost) =
      let altDistance = addDist (d0 !?? current) (Dist cost)
      in  if altDistance < d0 !?? neighborNode
            then DijkstraState v1 (HM.insert neighborNode altDistance d0) (H.insert (altDistance, neighborNode) q1)
            else ds

graph1 :: Graph String Int
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
James Bowen James Bowen

Dijkstra's Algorithm in Haskell

In some of my recent streaming sessions (some of which you can see on my YouTube channel), I spent some time playing around with Dijkstra’s algorithm. I wrote my own version of it in Haskell, tried to generalize it to work in different settings, and then used it in some examples. So for the next couple weeks I’ll be writing about those results. Today, though, I’ll start with a quick overview of a basic Haskell approach to the problem.

Note: This article will follow the “In Depth” reading style I talked about last week. I’ll be including all the details of my code, so if you want to follow along with this article, everything should compile and work! I’ll list dependencies, imports, and the complete code in an appendix at the end.

Pseudocode

Before we can understand how to write this algorithm in Haskell specifically, we need to take a quick look at the pseudocode. This is adapted from the Wikipedia description:

function Dijkstra(Graph, source):

      for each vertex v in Graph.Vertices:
          dist[v] <- INFINITY
          add v to Q
      dist[source] <- 0

      while Q is not empty:
          u <- vertex in Q with min dist[u]
          remove u from Q

          for each neighbor v of u still in Q:
              alt <- dist[u] + Graph.Edges(u, v)
              if alt < dist[v] and dist[u] is not INFINITY:
                  dist[v] <- alt

      return dist[]

There are a few noteworthy items here. This code references two main structures. We have dist, the mapping of nodes to distances. There is also Q, which has a special operation “vertex in Q with min dist[u]”. It also has the operation “still in Q”. We can actually separate this into two items. We can have one structure to track the minimum distance to nodes, and then we have a second to track which ones are “visited”.

With this in mind, we can break the "Dijkstra Process" into 5 distinct steps.

  1. Define our type signature
  2. Initialize a structure with the different items (Q, dist, etc.) in their initial states
  3. Write a loop for processing each element from the queue.
  4. Write an inner loop for processing each “neighbor” we encounter of the items pulled from the queue.
  5. Get our answer from the final structure.

These steps will help us organize our code. Before we dive into the algorithm itself though, we’ll want a few helpers!

Helpers

There are a few specific items that aren’t really part of the algorithm, but they’ll make things a lot smoother for us. First, we’re going to define a new “Distance” type. This will include a general numeric type for the distance but also include an “Infinity” constructor, which will help us represent the idea of an unreachable value.

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

We want to ensure this type has a reasonable ordering, and that we can add values with it in a sensible way. If we instead used an integer type’s maxBound to stand in for infinity, adding any positive cost to it would give us arithmetic overflow problems.

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist x = False
  Dist x <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity
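
Here is a quick sketch of how these pieces behave. The definitions are repeated so the block compiles on its own, and the names overflowed and closest are just for illustration. The overflow value shows the wrap-around behavior I see for Int in GHC.

```haskell
data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist _ = False
  Dist _ <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

-- Using maxBound as "infinity" breaks as soon as we add to it.
-- In GHC this wraps around to a negative number.
overflowed :: Int
overflowed = maxBound + 1

-- Any finite distance sorts before Infinity.
closest :: Distance Int
closest = minimum [Infinity, Dist 7, Dist 3]
```

So closest is Dist 3, and addDist (Dist 1) Infinity stays Infinity instead of wrapping around the way overflowed does.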

Now we’ll be tracking our distances with a Hash Map. So as an extra convenience, we’ll add an operator that will look up a particular item in our map, returning its distance if that exists, but otherwise returning “Infinity” if it does not.

(!??) :: (Hashable k, Eq k) => HashMap k (Distance d) -> k -> Distance d
(!??) distanceMap key = fromMaybe Infinity (HM.lookup key distanceMap)

Type Signature

Now we move on to the next step in our process: defining a type signature. To start, we need to ask, "what kind of graph are we working with?" We’ll generalize this in the future, but for now let’s assume we are defining our graph entirely as a map of “Nodes” (represented by string names) to “Edges”, which are tuples of names and costs for the distance.

newtype Graph = Graph
   { edges :: HashMap String [(String, Int)] }

Our dijkstra function will take such a graph, a starting point (the source), and an ending point (the destination) as its parameters. It will return the “Distance” from the start to the end (which could be infinite). For now, we’ll exclude returning the full path, but we’ll get to that by the end of the series.

findShortestDistance :: Graph -> String -> String -> Distance Int

With our graph defined more clearly now, we’ll want to define one more “stateful” type to use within our algorithm. From reading the pseudocode, we want this to contain three structures that vary with each iteration.

  1. A set of nodes we’ve visited.
  2. The distance map, from nodes to their “current” distance values
  3. A queue allowing us to find the “unvisited node with the smallest current distance”.

For the first two items, it's straightforward to see what types we use. A HashSet of strings will suffice for the visited set, and likewise a HashMap will help us track distances. For the queue, we need to be a little clever, but a priority heap using the distance and the node will be most helpful.

data DijkstraState = DijkstraState
  { visitedSet :: HashSet String
  , distanceMap :: HashMap String (Distance Int)
  , nodeQueue :: MinPrioHeap (Distance Int) String
  }
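The queue depends on the Ord instance for Distance, because a min-priority heap hands back the entry with the smallest priority first. As a rough illustration that avoids the heap library entirely, sorting a list of (distance, node) pairs shows the order in which entries would come out. The instance is repeated here so the snippet stands alone, and the `entries` list is made up.

```haskell
import Data.List (sort)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist _ = False
  Dist _ <= Infinity = True
  Dist x <= Dist y = x <= y

-- Hypothetical queue contents, in arbitrary insertion order.
entries :: [(Distance Int, String)]
entries = [(Dist 20, "C"), (Infinity, "X"), (Dist 1, "B")]

-- Sorting is only a stand-in for repeatedly popping a min-heap.
popOrder :: [String]
popOrder = map snd (sort entries)
```

Finite distances come out smallest first, and anything at Infinity comes out last.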

Initializing Values

Now for step 2, let’s build our initial state, taking the form of the DijkstraState type we defined above. Initially, we will consider that there are no visited nodes. Then we’ll define that the only distance we have is a distance of “0” to the source node. We’ll also want to store this pair in the queue, so that the source is the first node we pull out of our queue.

findShortestDistance :: Graph -> String -> String -> Distance Int
findShortestDistance graph src dest = ...
  where
    initialVisited = HS.empty
    initialDistances = HM.singleton src (Dist 0)
    initialQueue = H.fromList [(Dist 0, src)]
    initialState = DijkstraState initialVisited initialDistances initialQueue
    ...

Processing the Queue

Now we’ll write a looping function that will process the elements in our queue. This function will return for us the mapping of nodes to distances. It will take the current DijkstraState as its input. Remember that the most basic method we have of looping in Haskell, particularly when it comes to a “while” loop, is a recursive function.

Every recursive function needs at least one base case. So let’s start with one of those. If the queue is empty, we can return the map as it is.

findShortestDistance graph src dest = ...
  where
    ...
    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      ...

Next come the cases where the queue contains at least one element. Suppose this element is our destination. We can also return the distance map immediately here, as it will already contain the distance to that point.

findShortestDistance graph src dest = ...
  where
    ...
    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else ...

One last early-exit case: if the node is already visited, we skip it and immediately recurse, plugging in the new queue q1 for the old queue.

findShortestDistance graph src dest = ...
  where
    ...
    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else ...
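The visited check matters because the queue can hold stale entries: a node may be inserted more than once with different distances, and only its first pop should count. Here is a small sketch of that skipping behavior. It uses a sorted list as a stand-in queue and Data.Set for the visited set, both chosen only for illustration, and `popFresh` is a hypothetical helper rather than part of the code in this article.

```haskell
import qualified Data.Set as S

-- Pop entries from an already-sorted queue until we reach a node that has
-- not been visited; return that entry together with the remaining queue.
popFresh :: S.Set String -> [(Int, String)] -> Maybe ((Int, String), [(Int, String)])
popFresh _ [] = Nothing
popFresh visited (entry@(_, node) : rest)
  | S.member node visited = popFresh visited rest  -- stale entry, skip it
  | otherwise = Just (entry, rest)

-- "B" was already visited, so its leftover entry is discarded.
example :: Maybe ((Int, String), [(Int, String)])
example = popFresh (S.fromList ["A", "B"]) [(1, "B"), (20, "C"), (100, "D")]
```

Our processQueue does the same thing one entry at a time: each recursive call either discards a visited node or moves on to process a fresh one.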

Now, on to the recursive case. In this case we will do 3 things.

  1. Pull a new node from our heap and consider that node “visited”
  2. Get all the “neighbors” of this node
  3. Process each neighbor and update its distance

Most of the work in step 3 will happen in our “inner loop”. The basics for the first two steps are quite easy.

findShortestDistance graph src dest = ...
  where
    ...
    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else
          -- Update the visited set
          let v1 = HS.insert node v0
          -- Get all unvisited neighbors of our current node
              allNeighbors = fromMaybe [] (HM.lookup node (edges graph))
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors

Now we just need to process each neighbor. We can do this using a fold. Our “folding function” will have a type signature that incorporates the current node as a “fixed” argument while otherwise following the a -> b -> a pattern of a left fold. Each step will incorporate a new node with its cost and update the DijkstraState. This means the a value in our folding function is DijkstraState, while the b value is (String, Int).

foldNeighbor :: String -> DijkstraState -> (String, Int) -> DijkstraState
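To make the shape of this fold concrete, here is a sketch with simpler types. The hypothetical `step` function takes a fixed label first, then the accumulator, then the next element. Partially applying it to the label leaves exactly the a -> b -> a function that foldl expects.

```haskell
-- Hypothetical folding function: the first argument stays fixed for the
-- whole fold, the second is the accumulator, the third is the next element.
step :: String -> [String] -> (String, Int) -> [String]
step current acc (neighbor, cost) =
  acc ++ [current ++ "->" ++ neighbor ++ ":" ++ show cost]

-- (step "A") has type [String] -> (String, Int) -> [String].
described :: [String]
described = foldl (step "A") [] [("D", 100), ("B", 1), ("C", 20)]
```

Here the accumulator is just a list of strings, where our real fold threads a DijkstraState through instead.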

With this type signature set, we can now complete our processQueue function before implementing this inner loop. We call a foldl over the new neighbors, and then we recurse over the resulting DijkstraState.

findShortestDistance graph src dest = ...
  where
    ...
    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      ...
        else
          -- Update the visited set
          let v1 = HS.insert node v0
          -- Get all unvisited neighbors of our current node
              allNeighbors = fromMaybe [] (HM.lookup node (edges graph))
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors
          -- Fold each neighbor and recursively process the queue
          in  processQueue $ foldl (foldNeighbor node) (DijkstraState v1 d0 q1) unvisitedNeighbors

The Final Fold

Now let’s write this final fold, our “inner loop” function foldNeighbor. The core job of this function is to calculate the “alternative” distance to the given neighbor by going “through” the current node. This consists of taking the distance from the source to the current node (which is stored in the distance map) and adding it to the specific edge cost from the current to this new node.

foldNeighbor :: String -> DijkstraState -> (String, Int) -> DijkstraState
foldNeighbor current (DijkstraState v1 d0 q1) (neighborNode, cost) =
  let altDistance = addDist (d0 !?? current) (Dist cost)
  ...

We can then compare this distance to the existing distance we have to the neighbor in our distance map (or Infinity if it doesn’t exist, remember).

foldNeighbor current ds@(DijkstraState v1 d0 q1) (neighborNode, cost) =
  let altDistance = addDist (d0 !?? current) (Dist cost)
  in  if altDistance < d0 !?? neighborNode
  ...

If the alternative distance is smaller, we update the distance map by associating the neighbor node with the alternative distance and return the new DijkstraState. We also insert the new distance into our queue. If the alternative distance is not better, we make no changes, and return the original state.

foldNeighbor current ds@(DijkstraState v1 d0 q1) (neighborNode, cost) =
  let altDistance = addDist (d0 !?? current) (Dist cost)
  in  if altDistance < d0 !?? neighborNode
        then DijkstraState v1 (HM.insert neighborNode altDistance d0) (H.insert (altDistance, neighborNode) q1)
        else ds
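As a worked trace of this update rule, the sketch below applies a simplified version of the fold to the edge costs of the small example graph that appears later in this article. It tracks only the distance map (the real foldNeighbor also updates the queue), and it uses Data.Map from containers in place of HashMap so it runs without extra packages. The names `relax`, `afterA`, `afterB` and `afterC` are hypothetical.

```haskell
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist _ = False
  Dist _ <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

-- Simplified update step: keeps the better of the old and alternative distance.
relax :: String -> M.Map String (Distance Int) -> (String, Int) -> M.Map String (Distance Int)
relax current dists (neighbor, cost) =
  let lookupDist k = fromMaybe Infinity (M.lookup k dists)
      alt = addDist (lookupDist current) (Dist cost)
  in  if alt < lookupDist neighbor
        then M.insert neighbor alt dists
        else dists

afterA, afterB, afterC :: M.Map String (Distance Int)
afterA = foldl (relax "A") (M.singleton "A" (Dist 0)) [("D", 100), ("B", 1), ("C", 20)]
afterB = foldl (relax "B") afterA [("D", 50)]  -- D improves from 100 to 1 + 50
afterC = foldl (relax "C") afterB [("D", 20)]  -- D improves again to 20 + 20
```

Each step only ever lowers a distance, so the entry for D moves from 100 to 51 to 40 as better routes are found.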

Now we are essentially done! All that’s left to do is run the process queue function from the top level and get the distance for the destination point.

findShortestDistance :: Graph -> String -> String -> Distance Int
findShortestDistance graph src dest = processQueue initialState !?? dest
  where
    initialState = ...
    processQueue = ...

Our code is complete, and we can construct a simple example and see that it works!

graph1 :: Graph
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]

...

{- GHCI -}
>> findShortestDistance graph1 "A" "D"
Dist 40
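It's also worth checking what happens when the destination can't be reached at all. The sketch below is a compact variant of the same algorithm, not the code from this article: it uses Data.Map and Data.Set from containers, with a Set of (distance, node) pairs standing in for the priority heap, so that it runs without extra packages. The names `SimpleGraph`, `shortest` and `sampleGraph` are hypothetical.

```haskell
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist _ = False
  Dist _ <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

type SimpleGraph = M.Map String [(String, Int)]

-- Same overall shape as findShortestDistance, with a Set as the queue.
shortest :: SimpleGraph -> String -> String -> Distance Int
shortest graph src dest =
  go S.empty (M.singleton src (Dist 0)) (S.singleton (Dist 0, src))
  where
    dist d k = fromMaybe Infinity (M.lookup k d)
    go visited d queue = case S.minView queue of
      Nothing -> dist d dest  -- queue exhausted
      Just ((_, node), rest)
        | node == dest -> dist d dest
        | S.member node visited -> go visited d rest  -- stale entry
        | otherwise ->
            let visited' = S.insert node visited
                neighbors = filter (\(n, _) -> not (S.member n visited'))
                                   (fromMaybe [] (M.lookup node graph))
                (d', queue') = foldl (relax node) (d, rest) neighbors
            in  go visited' d' queue'
    relax current (d, q) (n, cost) =
      let alt = addDist (dist d current) (Dist cost)
      in  if alt < dist d n then (M.insert n alt d, S.insert (alt, n) q) else (d, q)

-- Same edges as graph1.
sampleGraph :: SimpleGraph
sampleGraph = M.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]
```

Since "D" has no outgoing edges, searching from "D" to "A" empties the queue without ever finding "A", and the lookup falls back to Infinity.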

Next Steps

This algorithm is nice, but it’s very limited and specific to how we’ve constructed the Graph type. What if we wanted to expose a more general API to users?

Read on to see how to do this in the next part of this series!

Appendix

Here is the complete code for this article, including imports!

module DijkstraSimple where

{- required packages:
   heap, unordered-containers, hashable
-}

import qualified Data.HashMap.Strict as HM
import Data.HashMap.Strict (HashMap)
import qualified Data.Heap as H
import Data.Heap (MinPrioHeap)
import qualified Data.HashSet as HS
import Data.HashSet (HashSet)
import Data.Hashable (Hashable)
import Data.Maybe (fromMaybe)

data Distance a = Dist a | Infinity
  deriving (Show, Eq)

instance (Ord a) => Ord (Distance a) where
  Infinity <= Infinity = True
  Infinity <= Dist x = False
  Dist x <= Infinity = True
  Dist x <= Dist y = x <= y

addDist :: (Num a) => Distance a -> Distance a -> Distance a
addDist (Dist x) (Dist y) = Dist (x + y)
addDist _ _ = Infinity

(!??) :: (Hashable k, Eq k) => HashMap k (Distance d) -> k -> Distance d
(!??) distanceMap key = fromMaybe Infinity (HM.lookup key distanceMap)

newtype Graph = Graph
   { edges :: HashMap String [(String, Int)] }

data DijkstraState = DijkstraState
  { visitedSet :: HashSet String
  , distanceMap :: HashMap String (Distance Int)
  , nodeQueue :: MinPrioHeap (Distance Int) String
  }


findShortestDistance :: Graph -> String -> String -> Distance Int
findShortestDistance graph src dest = processQueue initialState !?? dest
  where
    initialVisited = HS.empty
    initialDistances = HM.singleton src (Dist 0)
    initialQueue = H.fromList [(Dist 0, src)]
    initialState = DijkstraState initialVisited initialDistances initialQueue

    processQueue :: DijkstraState -> HashMap String (Distance Int)
    processQueue ds@(DijkstraState v0 d0 q0) = case H.view q0 of
      Nothing -> d0
      Just ((minDist, node), q1) -> if node == dest then d0
        else if HS.member node v0 then processQueue (ds {nodeQueue = q1})
        else
          -- Update the visited set
          let v1 = HS.insert node v0
          -- Get all unvisited neighbors of our current node
              allNeighbors = fromMaybe [] (HM.lookup node (edges graph))
              unvisitedNeighbors = filter (\(n, _) -> not (HS.member n v1)) allNeighbors
          -- Fold each neighbor and recursively process the queue
          in  processQueue $ foldl (foldNeighbor node) (DijkstraState v1 d0 q1) unvisitedNeighbors
    foldNeighbor current ds@(DijkstraState v1 d0 q1) (neighborNode, cost) =
      let altDistance = addDist (d0 !?? current) (Dist cost)
      in  if altDistance < d0 !?? neighborNode
            then DijkstraState v1 (HM.insert neighborNode altDistance d0) (H.insert (altDistance, neighborNode) q1)
            else ds

graph1 :: Graph
graph1 = Graph $ HM.fromList
  [ ("A", [("D", 100), ("B", 1), ("C", 20)])
  , ("B", [("D", 50)])
  , ("C", [("D", 20)])
  , ("D", [])
  ]