Stock Market Shark: More Multidimensional DP
Today we’ll tackle the final problem (for now) in our series comparing Rust and Haskell LeetCode solutions. We’ll do a wrap-up of some of the important lessons next week. Last week’s problem was a multi-dimensional dynamic programming problem where the “dimensions” were obvious. We were working in 2D space trying to find the largest square, so we wanted the cells in our “DP” grid to correspond to the cells in our input grid.
Today we’ll solve one final problem using DP in multiple dimensions where the dimensions aren’t quite as obvious. To learn more about the basics behind implementing DP in Haskell, you need to enroll in our course, Solve.hs! You’ll learn many principles about algorithms in Module 3 and get a ton of practice with our exercises!
The Problem
Today’s problem is Best Time to Buy and Sell Stock IV, the final problem in a series where we aim to maximize the profit we can make from buying and selling a single stock.
We have two problem inputs. The first is an array of the prices of the stock over a number of days. Each day has one price. There is no fluctuation over the course of a day (real world stock trading would be much easier if we got this kind of future data!).
Our second input is a number of “transactions” we can make. A single transaction consists of buying AND selling the stock. There are some restrictions on how these transactions work. The primary one is that we cannot have simultaneous transactions. Another way of saying this is that we can only hold one “instance” of the stock at a time. We can’t buy one instance of the stock on day 1, and then another instance on day 2, and then sell them both later.
We also cannot sell a stock on the same day we buy it, nor buy a new instance on the same day we sell a previous instance. This isn’t so much a problem constraint as an algorithmic insight that there is no benefit to us doing this. Buying and selling on the same day yields no net profit, so we may as well just not use the transaction.
As an example, suppose we have 3 transactions to use, and the following data for the upcoming days:
[1, 4, 8, 2, 7, 1, 15]
The solution here is 26, via the following transactions:
- Buy the stock on day 1 for $1, sell it on day 3 for $8 ($7 profit)
- Buy the stock on day 4 for $2, sell it on day 5 for $7 ($5 profit)
- Buy the stock on day 6 for $1, sell it on day 7 for $15 ($14 profit)
If we only had 2 transactions to work with, the answer would be 21. We would simply omit the second transaction.
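We can double-check this arithmetic with a quick standalone snippet (a sanity check of the example only, not part of the solution; the helper name is our own):

```rust
// Sanity-checking the worked example: each transaction contributes
// (sell price - buy price) to the total profit.
fn total_profit(transactions: &[(i32, i32)]) -> i32 {
    transactions.iter().map(|(buy, sell)| sell - buy).sum()
}

fn main() {
    // The three transactions listed above, as (buy, sell) prices.
    println!("{}", total_profit(&[(1, 8), (2, 7), (1, 15)])); // 26
    // With only 2 transactions, we omit the middle one.
    println!("{}", total_profit(&[(1, 8), (1, 15)])); // 21
}
```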
The Algorithm
Since this is a “hard” problem, the algorithm description is a bit tricky! But we can break it into a few pieces.
Grid Structure
As I alluded to, this is a multi-dimensional DP problem, but the “dimensions” are not as clear as our last problem, because this problem doesn’t have a spatial nature. But once you do enough DP problems, it gets easier to see what the dimensions are.
One dimension will be the “current day”, and the other will be the “transaction state”. The cell {s, d} will indicate “Given I am in state s on day d, what is the largest additional profit I can achieve?”
The number of days is obviously equal to the size of our input array. This will be our column dimension. So column i will always mean “if I am in this state on day i”.
The number of transaction states is actually double the number of transactions we are allowed. We want one row for each transaction to capture the state after we have bought for this transaction, and one row for before buying as part of this transaction (we’ll refer to this row as “pre-bought” throughout).
We’ll order the rows so that earlier rows represent fewer transactions remaining. Thus the first row indicates the state of having purchased the stock for the final transaction, but not yet having sold it. The second row indicates you have one transaction still available, but you haven’t bought the stock for this transaction yet. The third row indicates you have purchased the stock and you’ll have 1 complete transaction remaining after selling it. And so on. So with n days and k transactions, our grid will have size 2k x n.
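To make this layout concrete, here’s a tiny sketch (with a hypothetical helper name of our own) of how a transaction’s position maps to rows under this ordering, where j = 0 is the final transaction:

```rust
// Hypothetical helper: for the transaction that is j-th from the end
// (j = 0 is the final transaction), the "bought" row is 2*j and the
// "pre-bought" row sits directly below it at 2*j + 1.
fn rows_for_transaction(j: usize) -> (usize, usize) {
    (2 * j, 2 * j + 1)
}

fn main() {
    assert_eq!(rows_for_transaction(0), (0, 1)); // final transaction
    assert_eq!(rows_for_transaction(1), (2, 3)); // one complete transaction left after this one
    // With k transactions and n days, the grid has 2k rows and n columns.
    let (k, n) = (3, 7);
    println!("{} x {}", 2 * k, n); // 6 x 7 for the worked example above
}
```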
Base Cases
Now let’s think about the base cases of this grid. It is easiest to consider the last day, the final column of the grid. If we’re on the last day, the marginal gain we can make if we are holding the stock is simply to sell it (all prices are positive), which would give us a “profit” of the final sale price. We don’t need to consider the cost of buying the stock for these rows. We just think about “given that I have the stock, what’s the most I can end up with”.
Then, for all the “pre-bought” rows, the final column is 0. We don’t have enough time to buy AND sell a stock, so we just do nothing.
Now we can also populate the rows for the final transaction fairly easily. These are base cases as well. We’ll populate them from right to left, meaning from the later days to the earlier days (recall we’ve already filled in the very last day).
For the “top” row, where we’ve already bought the stock for our final transaction, we have two choices. We can “sell” the stock on that day, or “keep” the stock to sell later. The first option means we just use the price for that day, and the second means we use the recorded value for the next day. We want the maximum of these options.
Once we’ve populated the “bought” row, we move on to the “pre-bought” row below it. Again, we’ll loop right to left and have two options each time. We can “buy” the stock, which would move us “up” to the bought row on the next day, except we have to subtract the price of the stock. Or we can “stay” and not buy the stock. This means we grab the value from the same row in the next column. Again, we just use the max of these two options.
At this point, we’ve populated the entire last column of our grid AND the first two rows.
Recursive Cases
For the “recursive” cases (we can actually think of them as “inductive” cases), we go two rows at a time, counting up to our total transaction count. Each transaction follows the same pattern, which is similar to what we did for the rows above.
First, fill in the “bought” row for this transaction. We can “sell” or “keep” the stock. Selling moves us up and to the right, and adds the sale price for that day. But keeping moves us directly right in our grid. Again, we take the max of these options.
Then we fill the “pre-bought” row for this transaction. We can “buy” or “stay”. Buying means subtracting the price for that day from the value up and to the right. Staying means we take the value immediately to our right. As always, take the max.
When we’ve completely populated our grid following this pattern, our final answer is the value in the bottom left of the grid! This is the maximum profit starting from day 0, before buying for any of our transactions, which is the true starting state of the problem.
Rust Solution
Let’s solve this in Rust first! We begin by defining a few values and handling an edge case (if there’s only 1 day, the answer is 0 since we can’t buy and sell).
pub fn max_profit(k: i32, prices: Vec<i32>) -> i32 {
let n = prices.len();
if n == 1 {
return 0;
}
let ku = k as usize;
let numRows = 2 * ku;
// Create our zero-ed out grid
let mut dp: Vec<Vec<i32>> = vec![vec![0; n]; numRows];
...
}
Now we handle the first two rows (our “final” transaction). In each case, we start with the base case of the final day, and then move from right to left, following the rules described in the algorithm.
pub fn max_profit(k: i32, prices: Vec<i32>) -> i32 {
...
// Final Transaction
// Always sell on the last day!
dp[0][n - 1] = prices[n - 1];
for i in (0..=(n-2)).rev() {
// Sell or Keep
dp[0][i] = std::cmp::max(prices[i], dp[0][i+1]);
}
dp[1][n - 1] = 0;
for i in (0..=(n-2)).rev() {
// Buy (subtract price!) or keep
dp[1][i] = std::cmp::max(dp[0][i+1] - prices[i], dp[1][i+1]);
}
...
}
Now we write our core loop, going through the remaining transaction count. We start by defining the correct row numbers and setting the final-column base cases:
pub fn max_profit(k: i32, prices: Vec<i32>) -> i32 {
// Setup
...
// Final Transaction
...
// All other transactions
for j in 1..ku {
let boughtRow = 2 * j;
let preBoughtRow = boughtRow + 1;
// Always sell on the last day!
dp[boughtRow][n - 1] = prices[n - 1];
// 0 - No time to buy/sell!
dp[preBoughtRow][n - 1] = 0;
...
}
}
And now we apply the logic for our algorithm. As we populate each row from right to left, we simply apply our two choices: sell/keep for the “bought” row and buy/stay for the “pre-bought” row.
pub fn max_profit(k: i32, prices: Vec<i32>) -> i32 {
...
// All other transactions
for j in 1..ku {
let boughtRow = 2 * j;
let preBoughtRow = boughtRow + 1;
// Always sell on the last day!
dp[boughtRow][n - 1] = prices[n - 1];
// 0 - No time to buy/sell!
dp[preBoughtRow][n - 1] = 0;
// Sell or Keep!
for i in (0..=(n-2)).rev() {
dp[boughtRow][i] = std::cmp::max(dp[boughtRow - 1][i+1] + prices[i], dp[boughtRow][i + 1]);
}
// Buy or Stay!
for i in (0..=(n-2)).rev() {
dp[preBoughtRow][i] = std::cmp::max(dp[boughtRow][i+1] - prices[i], dp[preBoughtRow][i + 1]);
}
}
return dp[numRows - 1][0];
}
This completes our loop, and the final thing we need, as you can see, is to return the value in the bottom left of our grid!
Here is the complete solution:
pub fn max_profit(k: i32, prices: Vec<i32>) -> i32 {
let n = prices.len();
if n == 1 {
return 0;
}
let ku = k as usize;
let numRows = 2 * ku;
let mut dp: Vec<Vec<i32>> = vec![vec![0; n]; numRows];
// Final Transaction
dp[0][n - 1] = prices[n - 1];
for i in (0..=(n-2)).rev() {
dp[0][i] = std::cmp::max(prices[i], dp[0][i+1]);
}
dp[1][n - 1] = 0;
for i in (0..=(n-2)).rev() {
dp[1][i] = std::cmp::max(dp[0][i+1] - prices[i], dp[1][i+1]);
}
// All other transactions
for j in 1..ku {
let boughtRow = 2 * j;
let preBoughtRow = boughtRow + 1;
dp[boughtRow][n - 1] = prices[n - 1];
dp[preBoughtRow][n - 1] = 0;
for i in (0..=(n-2)).rev() {
dp[boughtRow][i] = std::cmp::max(dp[boughtRow - 1][i+1] + prices[i], dp[boughtRow][i + 1]);
}
for i in (0..=(n-2)).rev() {
dp[preBoughtRow][i] = std::cmp::max(dp[boughtRow][i+1] - prices[i], dp[preBoughtRow][i + 1]);
}
}
return dp[numRows - 1][0];
}
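As a sanity check against the example from earlier, here’s a condensed, standalone variant of the same recurrence (the names here are our own, not from the listing above) that keeps only the rows for one transaction at a time instead of the full 2k x n grid — a preview of the memory optimization the Haskell version will use:

```rust
// Condensed two-rows-at-a-time variant of the grid solution above.
// We only ever keep the current "bought" row and "pre-bought" row.
fn max_profit_rows(k: usize, prices: &[i32]) -> i32 {
    let n = prices.len();
    if n <= 1 || k == 0 {
        return 0;
    }
    // Base case: "bought" row for the final transaction (sell or keep).
    let mut bought = vec![0i32; n];
    bought[n - 1] = prices[n - 1];
    for i in (0..n - 1).rev() {
        bought[i] = std::cmp::max(prices[i], bought[i + 1]);
    }
    // "Pre-bought" row for the final transaction (buy or stay).
    let mut pre_bought = vec![0i32; n];
    for i in (0..n - 1).rev() {
        pre_bought[i] = std::cmp::max(bought[i + 1] - prices[i], pre_bought[i + 1]);
    }
    // All other transactions, two rows at a time.
    for _ in 1..k {
        let mut next_bought = vec![0i32; n];
        next_bought[n - 1] = prices[n - 1];
        for i in (0..n - 1).rev() {
            next_bought[i] = std::cmp::max(pre_bought[i + 1] + prices[i], next_bought[i + 1]);
        }
        let mut next_pre = vec![0i32; n];
        for i in (0..n - 1).rev() {
            next_pre[i] = std::cmp::max(next_bought[i + 1] - prices[i], next_pre[i + 1]);
        }
        pre_bought = next_pre;
    }
    pre_bought[0]
}

fn main() {
    let prices = [1, 4, 8, 2, 7, 1, 15];
    println!("{}", max_profit_rows(3, &prices)); // 26, matching the worked example
    println!("{}", max_profit_rows(2, &prices)); // 21
}
```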
Haskell Solution
As we saw in our first DP problem, we often don’t need as much memory as it initially seems. We filled out the “whole grid” for Rust, which helps make the algorithm more clear. But our Haskell solution will reflect the fact that we only actually need to pass along one preceding row (the pre-bought row) each time we loop through a transaction.
Let’s start by defining our edge case, as well as a few useful terms. We’ll define our indices in left-to-right order, but in all cases we’ll loop through them in reverse with foldr:
maxProfit :: V.Vector Int -> Int -> Int
maxProfit nums k = if n == 1 then 0
else ...
where
n = V.length nums
lastPrice = nums V.! (n - 1)
idxs = ([0..(n-2)] :: [Int])
...
Now we’ll define three different “loop” functions, all with the same pattern. We’ll use an IntMap Int to represent each “row” in our grid. So these functions will modify the IntMap for the row as we go along, while taking the new “index” we are populating. Let’s start with the base case, the first “bought” row, corresponding to our final transaction.
It will give us two options: sell or keep, following our algorithm. We insert the max of these into the map.
maxProfit :: V.Vector Int -> Int -> Int
maxProfit nums k = if n == 1 then 0
else ...
where
n = V.length nums
lastPrice = nums V.! (n - 1)
idxs = ([0..(n-2)] :: [Int])
ibFold :: Int -> IM.IntMap Int -> IM.IntMap Int
ibFold i mp =
let sell = nums V.! i
keep = mp IM.! (i + 1)
in IM.insert i (max sell keep) mp
initialBought = foldr ibFold (IM.singleton (n-1) lastPrice) idxs
...
We construct our initialBought row by folding, starting with a singleton of the last column base case.
Now we’ll write a function that, given a “bought” row, can construct the preceding “pre-bought” row. This will apply the “buy” and “stay” ideas in our algorithm and select between them. Choosing the “buy” option requires looking into the preceding “bought” row, while “stay” looks at a later index of the existing map:
maxProfit :: V.Vector Int -> Int -> Int
maxProfit nums k = if n == 1 then 0
else ...
where
...
initialBought = foldr ibFold (IM.singleton (n-1) lastPrice) idxs
preBoughtFold :: IM.IntMap Int -> Int -> IM.IntMap Int -> IM.IntMap Int
preBoughtFold bought i preBought =
let buy = bought IM.! (i+1) - nums V.! i
stay = preBought IM.! (i+1)
in IM.insert i (max buy stay) preBought
initialPreBought = foldr (preBoughtFold initialBought) (IM.singleton (n-1) 0) idxs
We construct the initialPreBought row by applying this function with initialBought as the input. But we’ll use this for the rest of our “pre-bought” rows as well! First though, we need a more general loop for the rest of our “bought” rows.
This function has the same structure as pre-bought, just applying the “sell” and “keep” rules instead of “buy” and “stay”:
maxProfit :: V.Vector Int -> Int -> Int
maxProfit nums k = if n == 1 then 0
else ...
where
...
boughtFold :: IM.IntMap Int -> Int -> IM.IntMap Int -> IM.IntMap Int
boughtFold preBought i bought =
let sell = preBought IM.! (i+1) + nums V.! i
keep = bought IM.! (i+1)
in IM.insert i (max sell keep) bought
Now we’re ready for our core loop! This will loop through every transaction except the base case. It takes only the preceding “pre-bought” row and the transaction counter. Once the counter reaches k, we return the first value in this row. Otherwise, we run the “bought” loop to produce a next “bought” row, and we pass this in to the “pre-bought” loop to produce a new “pre-bought” row. This becomes the input to our recursive call:
maxProfit :: V.Vector Int -> Int -> Int
maxProfit nums k = if n == 1 then 0
else loop 1 initialPreBought
where
...
loop :: Int -> IM.IntMap Int -> Int
loop i preBought = if i >= k then preBought IM.! 0
else
let bought' = foldr (boughtFold preBought) (IM.singleton (n-1) lastPrice) idxs
preBought' = foldr (preBoughtFold bought') (IM.singleton (n-1) 0) idxs
in loop (i + 1) preBought'
As you can see above, we complete the solution by calling our loop with the initial “pre-bought” row, and a transaction counter of 1!
Here’s our full Haskell solution:
import qualified Data.IntMap as IM
import qualified Data.Vector as V

maxProfit :: V.Vector Int -> Int -> Int
maxProfit nums k = if n == 1 then 0
else loop 1 initialPreBought
where
n = V.length nums
lastPrice = nums V.! (n - 1)
idxs = ([0..(n-2)] :: [Int])
ibFold :: Int -> IM.IntMap Int -> IM.IntMap Int
ibFold i mp =
let sell = nums V.! i
keep = mp IM.! (i + 1)
in IM.insert i (max sell keep) mp
initialBought = foldr ibFold (IM.singleton (n-1) lastPrice) idxs
preBoughtFold :: IM.IntMap Int -> Int -> IM.IntMap Int -> IM.IntMap Int
preBoughtFold bought i preBought =
let buy = bought IM.! (i+1) - nums V.! i
stay = preBought IM.! (i+1)
in IM.insert i (max buy stay) preBought
initialPreBought = foldr (preBoughtFold initialBought) (IM.singleton (n-1) 0) idxs
boughtFold :: IM.IntMap Int -> Int -> IM.IntMap Int -> IM.IntMap Int
boughtFold preBought i bought =
let sell = preBought IM.! (i+1) + nums V.! i
keep = bought IM.! (i+1)
in IM.insert i (max sell keep) bought
loop :: Int -> IM.IntMap Int -> Int
loop i preBought = if i >= k then preBought IM.! 0
else
let bought' = foldr (boughtFold preBought) (IM.singleton (n-1) lastPrice) idxs
preBought' = foldr (preBoughtFold bought') (IM.singleton (n-1) 0) idxs
in loop (i + 1) preBought'
Conclusion
That’s the last LeetCode solution we’re going to write for now! Hopefully you’ve now got a good sense of the differences between dynamic programming in Haskell and in a language like Rust. To learn more about the basics of these Haskell solutions, take a look at our course, Solve.hs! You’ll also get a ton of practice with hundreds of exercise problems in the course!
Next week we will switch gears, and start working on some interesting parsing problems!
Spatial DP: Finding the Largest Square
In the past two weeks we’ve explored a couple of different problems in dynamic programming. These were simpler 1-dimensional problems. But dynamic programming is often at its most powerful when you can work across multiple dimensions. Today we’ll consider a 2D spatial problem where we can use dynamic programming.
If you want to learn how to write dynamic programming solutions in Haskell from the ground up, take a look at our Solve.hs course. DP is one of several algorithmic approaches you’ll learn in Module 3!
The Problem
Today’s problem (Maximal Square) is fairly simple conceptually. We are given a grid of 1’s and 0’s like so:
10100
11111
00111
10101
We must return the size of the largest square in the grid composed entirely of 1’s. So in the example above, the answer would be 4. There are two 2x2 squares we can form, starting in the 2nd row, using either the 3rd or 4th column as the “top left” of the square.
We can make a couple of small edits to change the answer here. For example, we can flip the second ‘0’ in the bottom row and we’ll get a 3x3 square, allowing the answer 9:
10100
11111
00111
10111
We could instead flip the second ‘1’ in the third row, and now the answer is only 1, as there are no 2x2 squares remaining:
10100
11111
00101
10101
The Algorithm
To solve this, we can imagine a DP grid that has “layers”, where each layer has the same dimensions as the original grid. Each layer has a number “k” associated with it. The index {row, column} at layer k tells us whether or not a k x k square of 1’s exists in the original grid with the cell {row, column} as its top left cell.
To construct this grid, we would need a base case and a recursive case. The base case is to consider layer 1. This is identical to the original grid we receive. Any location with 1 in the original grid is the top left for a 1x1 square.
So how do we build the layer k+1? This requires one simple insight. Suppose we are dealing with a single index {r,c}. In order for this to be the top left of a square of size k+1, we just need to check that 4 cells begin squares of size k: {r,c}, {r+1,c}, {r,c+1},{r+1,c+1}.
So to form the next layer, we just loop through each index in the layer and fill it in with 1 if it meets that criterion. Once we reach a layer where each entry is 0, we are done. We should return the square of the last layer we found.
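The layer-building rule can be sketched on its own (a minimal sketch using a boolean-grid representation and helper names of our own, separate from the solutions below):

```rust
// Parse rows of '1'/'0' characters into a boolean grid (layer 1).
fn parse(rows: &[&str]) -> Vec<Vec<bool>> {
    rows.iter().map(|r| r.chars().map(|c| c == '1').collect()).collect()
}

// The layer rule: cell (r, c) starts a (k+1)-square exactly when
// (r,c), (r,c+1), (r+1,c), and (r+1,c+1) all start k-squares.
fn next_layer(layer: &[Vec<bool>]) -> Vec<Vec<bool>> {
    let m = layer.len();
    let n = layer[0].len();
    let mut out = vec![vec![false; n]; m];
    for r in 0..m - 1 {
        for c in 0..n - 1 {
            out[r][c] = layer[r][c]
                && layer[r][c + 1]
                && layer[r + 1][c]
                && layer[r + 1][c + 1];
        }
    }
    out
}

fn main() {
    // Layer 1 is just the example grid itself.
    let layer1 = parse(&["10100", "11111", "00111", "10101"]);
    let layer2 = next_layer(&layer1);
    // The two 2x2 squares start at row 2, columns 3 and 4 (1-indexed).
    println!("{} {}", layer2[1][2], layer2[1][3]); // true true
}
```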
There are a few optimizations possible here. Thinking back to our first DP problem, we didn’t need to store the full DP array since each new step only depended on a couple of prior values. This time, we don’t need a full grid with “k” layers. We could alternate between only two grids, computing new values from the prior grid, and then making our “new” grid the “old” grid for the next layer.
But even simpler than that, we can keep modifying a single grid in place. Each “new” value we calculate depends on numbers below and/or to its right. So as long as we loop through the grid from left to right and top to bottom, we are safe modifying its values in place. At least, that’s what we’ll do in Rust. In Haskell we could do this with the mutable array API, but we’ll stick with the more conventional, immutable, approach in this article. (You can learn more about Haskell’s mutable arrays in Solve.hs).
Rust Solution
Let’s start with the Rust solution, demonstrating the mutable array approach. We’ll start by defining a series of terms, like the dimensions of our input and our dp grid (which is initially a clone of the input). We’ll also define a boolean (found) to indicate if we’ve found at least a single 1 on the current layer. We’ll track level, the number of layers confirmed to have a 1.
pub fn maximal_square(matrix: Vec<Vec<char>>) -> i32 {
let m = matrix.len();
let n = matrix[0].len();
let mut level = 0;
let mut dp = matrix.clone();
let mut found = true;
...
return level * level;
}
Of course, our final answer is just the square of the final “level” we determine. But how do we find this? We’ll need an outer while loop that terminates once we hit a level that does not hold a 1. We reset found to false at the start of each loop, but at the end of the loop, we’ll increment the level if we have found something.
pub fn maximal_square(matrix: Vec<Vec<char>>) -> i32 {
let m = matrix.len();
let n = matrix[0].len();
let mut level = 0;
let mut dp = matrix.clone();
let mut found = true;
while found {
found = false;
...
if found {
level += 1;
}
}
return level * level;
}
Now the core of the “layer” loop is to loop through each cell, left to right and top to bottom.
pub fn maximal_square(matrix: Vec<Vec<char>>) -> i32 {
...
while found {
found = false;
for i in 0..m {
for j in 0..n {
...
}
}
if found {
level += 1;
}
}
return level * level;
}
So what happens inside the loop? When we hit a 0 cell, we don’t need to do anything. It always remains a 0 and we haven’t “found” anything. But interesting things happen if we hit a 1.
First, we note that found is now true - this layer is not empty, and we have found a k x k square. But second, we should reset this cell to 0 if it does not start a square of size k+1. We need to first check the dimensions to make sure we don’t go out of bounds, but then also check the 3 cells to the right, below, and diagonally away from us. If any of these are 0, we reset this cell to 0.
pub fn maximal_square(matrix: Vec<Vec<char>>) -> i32 {
...
while found {
found = false;
for i in 0..m {
for j in 0..n {
if dp[i][j] == '1' {
found = true;
if i + 1 >= m ||
j + 1 >= n ||
dp[i][j+1] == '0' ||
dp[i+1][j] == '0' ||
dp[i+1][j+1] == '0' {
dp[i][j] = '0';
}
}
}
}
if found {
level += 1;
}
}
return level * level;
}
And just by filling in this logic, our function is suddenly done! Our inner loop is complete, and our outer loop will break once we find no more increasingly large squares. Here is the full Rust solution:
pub fn maximal_square(matrix: Vec<Vec<char>>) -> i32 {
let m = matrix.len();
let n = matrix[0].len();
let mut level = 0;
let mut dp = matrix.clone();
let mut found = true;
while found {
found = false;
for i in 0..m {
for j in 0..n {
if dp[i][j] == '1' {
found = true;
if i + 1 >= m ||
j + 1 >= n ||
dp[i][j+1] == '0' ||
dp[i+1][j] == '0' ||
dp[i+1][j+1] == '0' {
dp[i][j] = '0';
}
}
}
}
if found {
level += 1;
}
}
return level * level;
}
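To sanity-check this against the three example grids from earlier, here’s a condensed, standalone restatement of the same in-place layered approach (the helper names here are our own), with a small parser so the snippet runs on its own:

```rust
// Compact restatement of the in-place layered approach above.
fn maximal_square_layers(rows: &[&str]) -> i32 {
    let m = rows.len();
    let n = rows[0].len();
    let mut dp: Vec<Vec<char>> = rows.iter().map(|r| r.chars().collect()).collect();
    let mut level = 0;
    let mut found = true;
    while found {
        found = false;
        for i in 0..m {
            for j in 0..n {
                if dp[i][j] == '1' {
                    found = true; // this layer still has a square
                    // Clear the cell unless it also starts a (k+1)-square.
                    if i + 1 >= m || j + 1 >= n
                        || dp[i][j + 1] == '0'
                        || dp[i + 1][j] == '0'
                        || dp[i + 1][j + 1] == '0'
                    {
                        dp[i][j] = '0';
                    }
                }
            }
        }
        if found {
            level += 1;
        }
    }
    level * level
}

fn main() {
    println!("{}", maximal_square_layers(&["10100", "11111", "00111", "10101"])); // 4
    println!("{}", maximal_square_layers(&["10100", "11111", "00111", "10111"])); // 9
    println!("{}", maximal_square_layers(&["10100", "11111", "00101", "10101"])); // 1
}
```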
Haskell Solution
Now let’s write this in Haskell. We’ll start with a few definitions, including a type alias for our DP map. We’ll take an Array as the problem input, but we want a HashMap for our stateful version since we can “mutate” a HashMap efficiently:
type SquareMap = HM.HashMap (Int, Int) Bool
maximalSquare :: A.Array (Int, Int) Bool -> Int
maximalSquare grid = ...
where
((minRow,minCol), (maxRow, maxCol)) = A.bounds grid
initialMap = HM.fromList (A.assocs grid)
...
Now we’ll define two loop functions - one for the inner loop, one for the outer loop. The “state” for the inner loop is our current level number, as well as the map of the previous layer. The inner loop (coordLoop) should return us an updated map, as well as the found bool value telling us if we’ve found at least a single 1 in the prior layer.
maximalSquare :: A.Array (Int, Int) Bool -> Int
maximalSquare grid = ...
where
((minRow,minCol), (maxRow, maxCol)) = A.bounds grid
initialMap = HM.fromList (A.assocs grid)
coordLoop :: (Bool, SquareMap) -> (Int, Int) -> (Bool, SquareMap)
coordLoop (found, mp) coord@(r, c) = ...
loop :: Int -> HM.HashMap (Int, Int) Bool -> Int
loop level mp = ...
...
Notice that coordLoop has the argument pattern for foldl, rather than foldr. We want to loop through our coordinates in the proper order, from left to right and top down. If we use a right fold over the indices of the grid, it will go in reverse order.
Let’s start by filling in the inner loop. The first thing to do is determine if the found value needs to change. This is the case if we discover a True value at this index:
maximalSquare :: A.Array (Int, Int) Bool -> Int
maximalSquare grid = ...
where
((minRow,minCol), (maxRow, maxCol)) = A.bounds grid
initialMap = HM.fromList (A.assocs grid)
coordLoop :: (Bool, SquareMap) -> (Int, Int) -> (Bool, SquareMap)
coordLoop (found, mp) coord@(r, c) =
let found' = found || mp HM.! coord
...
in (found', ...)
Now we need the 5 conditions that tell us if this cell should get cleared. Calculate all these, and insert False at the cell if any of them match. Otherwise, keep the map as is!
maximalSquare :: A.Array (Int, Int) Bool -> Int
maximalSquare grid = ...
where
((minRow,minCol), (maxRow, maxCol)) = A.bounds grid
initialMap = HM.fromList (A.assocs grid)
coordLoop :: (Bool, SquareMap) -> (Int, Int) -> (Bool, SquareMap)
coordLoop (found, mp) coord@(r, c) =
let found' = found || mp HM.! coord
tooRight = c >= maxCol
tooLow = r >= maxRow
toRight = mp HM.! (r, c + 1)
under = mp HM.! (r + 1, c)
diag = mp HM.! (r + 1, c + 1)
failNext = tooLow || tooRight || not toRight || not under || not diag
mp' = if failNext then HM.insert coord False mp else mp
in (found', mp')
...
Now for the outer loop, we use foldl to go through our coordinates using the coordLoop. If we’ve found at least 1 square at this size, then we recurse with the new map and an incremented size. Otherwise we return the square of the current level. Then we just need to call this loop with initial values:
type SquareMap = HM.HashMap (Int, Int) Bool
maximalSquare :: A.Array (Int, Int) Bool -> Int
maximalSquare grid = loop 0 initialMap
where
((minRow,minCol), (maxRow, maxCol)) = A.bounds grid
initialMap = HM.fromList (A.assocs grid)
coordLoop :: (Bool, SquareMap) -> (Int, Int) -> (Bool, SquareMap)
coordLoop (found, mp) coord@(r, c) = ...
loop :: Int -> HM.HashMap (Int, Int) Bool -> Int
loop level mp =
let (found, mp') = foldl coordLoop (False, mp) (A.indices grid)
in if found then loop (level + 1) mp' else (level * level)
This completes our Haskell solution!
import qualified Data.Array as A
import qualified Data.HashMap.Strict as HM

type SquareMap = HM.HashMap (Int, Int) Bool
maximalSquare :: A.Array (Int, Int) Bool -> Int
maximalSquare grid = loop 0 initialMap
where
((minRow,minCol), (maxRow, maxCol)) = A.bounds grid
initialMap = HM.fromList (A.assocs grid)
coordLoop :: (Bool, SquareMap) -> (Int, Int) -> (Bool, SquareMap)
coordLoop (found, mp) coord@(r, c) =
let found' = found || mp HM.! coord
tooRight = c >= maxCol
tooLow = r >= maxRow
toRight = mp HM.! (r, c + 1)
under = mp HM.! (r + 1, c)
diag = mp HM.! (r + 1, c + 1)
failNext = tooLow || tooRight || not toRight || not under || not diag
mp' = if failNext then HM.insert coord False mp else mp
in (found', mp')
loop :: Int -> HM.HashMap (Int, Int) Bool -> Int
loop level mp =
let (found, mp') = foldl coordLoop (False, mp) (A.indices grid)
in if found then loop (level + 1) mp' else (level * level)
Conclusion
Next week we’ll look at one more multi-dimensional DP problem where the dimensions aren’t quite as obvious in this spatial way. The best way to understand DP is to learn related concepts from scratch, including your basic use-it-or-lose-it problems and memoization. You’ll study all these concepts and learn Haskell implementation tricks in Module 3 of Solve.hs. Enroll in the course now!
Making Change: Array-Based DP
Today we’ll continue the study of Dynamic Programming we started last week. Last week’s problem let us use a very compact memory footprint, only remembering a couple of prior values. This week, we’ll study a very canonical DP problem that forces us to store a longer array of prior values to help us populate new solutions.
For an in-depth study of Dynamic Programming in Haskell and many other problem solving techniques, take a look at our Solve.hs course today! Module 3 focuses on algorithms, and introduces several steps leading up to understanding DP.
The Problem
Our problem today is Coin Change, and it’s relatively straightforward. We are given a list of coin values, and an amount to make change for. We want to find the smallest number of coins we can use to provide the given amount (or -1 if the amount cannot be made from the coins we have).
So for example, if we have coins [1,2,5], and we are trying to make 11 cents of change, the answer is 3 coins, because we take two 5-cent coins and one 1-cent coin.
If we have the coins [2,10,15] and the amount is 13, we should return -1, since there is no way to make 13 cents from these coins.
The Algorithm
Let us first observe that a greedy algorithm does not work here! We can’t simply take the largest coin under the remaining amount and then recurse. If we have coins like [1, 20, 25] and the amount is 40, we can do this with 2 coins (both 20), but taking a 25 coin to start is suboptimal.
The way we will do this is to build a DP array so that index i represents the fewest coins necessary to produce the amount i. All values are initially -1, to indicate that we might not be able to satisfy that amount. However, we can set index 0 to 0, since no coins are needed to give 0 cents.
So we have our base case, but how do we fill in index i, assuming we’ve filled in everything up to i - 1? The answer is that we will consider each coin we can use, and look back in the array based on its value. So if 5 is one of our coins, we’ll consider just adding 1 to the value at index i - 5. We’ll take the minimum value based on looking at all the different coin options, being careful to observe edge cases where no values are possible.
Unlike the last problem, this does require us to keep a larger array of values. We’re not just reaching back for the prior value in our array, we’re considering values that are much further back. Plus the amount of look-back we need is dynamic depending on the problem inputs.
We’ll write a solution where the array has size equal to the given amount (plus 1). It would be possible to instead use a structure whose size simply covers the range of possible coin values, but this becomes considerably more difficult.
Rust Solution
We’ll start with the Rust solution, since modifying arrays is more natural in Rust. What is unnatural in Rust is mixing integer types. Everything has to be usize if we’re going to index into arrays with it, so let’s start by converting the amount and the coins into usize:
pub fn coin_change(coins: Vec<i32>, amount: i32) -> i32 {
let n = amount as usize;
let cs: Vec<usize> = coins.into_iter().map(|x| x as usize).collect();
...
}
Now we’ll initialize our dp array. It should have a size equal to the amount plus 1 (we want indices 0 and amount to be valid). Most cells should initially be -1, but we’ll make the 0 index equal to 0 as our base case (no coins to make 0 cents of change). We’ll also return the final value from this array as our answer.
pub fn coin_change(coins: Vec<i32>, amount: i32) -> i32 {
let n = amount as usize;
let cs: Vec<usize> = coins.into_iter().map(|x| x as usize).collect();
let mut dp = Vec::with_capacity(n + 1);
dp.resize(n + 1, -1);
dp[0] = 0;
...
return dp[n];
}
Let’s set up our loops. We go through all the indices from 1 to amount, and loop through all the coins for each index.
pub fn coin_change(coins: Vec<i32>, amount: i32) -> i32 {
let n = amount as usize;
let cs: Vec<usize> = coins.into_iter().map(|x| x as usize).collect();
let mut dp = Vec::with_capacity(n + 1);
dp.resize(n + 1, -1);
dp[0] = 0;
for i in 1..=n {
for coin in &cs {
...
}
}
return dp[n];
}
Now let’s apply some rules for dealing with each coin. First, if the coin is larger than the index, we do nothing, since we can’t use it for this amount. Otherwise, we try to use it. We get a “previous” value for this coin, meaning we look at our dp table going back the number of spaces corresponding to the coin’s value.
pub fn coin_change(coins: Vec<i32>, amount: i32) -> i32 {
...
for i in 1..=n {
for coin in &cs {
if *coin <= i {
let prev = dp[i - coin];
...
}
}
}
return dp[n];
}
If the prior value is -1, we ignore it: this means we can’t actually use this coin to form the value at this index. Otherwise, we look at the current value in the dp table for this index; we may already have a value here from previous coins. If the current value is -1, or it is larger than prev + 1 (the count we get by using this new coin), we replace the value in the dp table:
```rust
pub fn coin_change(coins: Vec<i32>, amount: i32) -> i32 {
    let n = amount as usize;
    let cs: Vec<usize> = coins.into_iter().map(|x| x as usize).collect();
    let mut dp = Vec::with_capacity(n + 1);
    dp.resize(n + 1, -1);
    dp[0] = 0;
    for i in 1..=n {
        for coin in &cs {
            if *coin <= i {
                let prev = dp[i - coin];
                if prev != -1 && (dp[i] == -1 || prev + 1 < dp[i]) {
                    dp[i] = prev + 1;
                }
            }
        }
    }
    return dp[n];
}
```
And this completes our solution!
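As a quick sanity check, here is the finished function again in a self-contained form, together with the expected results from the standard LeetCode examples for this problem:

```rust
// Complete solution from above, unchanged.
pub fn coin_change(coins: Vec<i32>, amount: i32) -> i32 {
    let n = amount as usize;
    let cs: Vec<usize> = coins.into_iter().map(|x| x as usize).collect();
    let mut dp = Vec::with_capacity(n + 1);
    dp.resize(n + 1, -1);
    dp[0] = 0;
    for i in 1..=n {
        for coin in &cs {
            if *coin <= i {
                let prev = dp[i - coin];
                if prev != -1 && (dp[i] == -1 || prev + 1 < dp[i]) {
                    dp[i] = prev + 1;
                }
            }
        }
    }
    return dp[n];
}
```

Expected: coin_change(vec![1, 2, 5], 11) is 3 (5 + 5 + 1), coin_change(vec![2], 3) is -1, and coin_change(vec![1], 0) is 0.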
Haskell Solution
In Haskell, immutability makes DP with arrays a bit more challenging. We could use mutable arrays, but these are a little tricky (you can learn about them in Solve.hs).
Instead we’ll lean on the IntMap type, which is just like Data.Map but always uses Int for keys. This structure is immutable, like other map-like structures in Haskell: “updating” it produces a new map rather than modifying it in place. We’ll write a core loop that takes this map as its stateful input, as well as the index:
```haskell
import Data.Maybe (fromMaybe)
import qualified Data.IntMap.Lazy as IM

coinChange :: [Int] -> Int -> Int
coinChange coins amount = ...
  where
    loop :: IM.IntMap Int -> Int -> Int
    loop dp i = ...
```
A notable difference in how we’ll use our map is that we don’t store entries for invalid indices. These will simply be absent, and we’ll use fromMaybe when reading from the map to account for keys that might not exist. As a first example, let’s handle the base case of our loop. Once the index i exceeds our amount, we’ll return the value in our map at amount, or -1 if it doesn’t exist:
```haskell
coinChange :: [Int] -> Int -> Int
coinChange coins amount = ...
  where
    loop :: IM.IntMap Int -> Int -> Int
    loop dp i = if i > amount then fromMaybe (-1) (IM.lookup amount dp)
      else ...
```
Now we need to loop through the coins while updating our IntMap. Hopefully you can guess what’s coming. We need to define a function that ends with a -> b -> b, where a is the new coin we’re processing and b is the IntMap. Then we can loop through the coins with foldr. This function will also take our current index, which will be constant across the loop of coins:
```haskell
coinChange :: [Int] -> Int -> Int
coinChange coins amount = ...
  where
    coinLoop :: Int -> Int -> IM.IntMap Int -> IM.IntMap Int
    coinLoop i coin dp = ...
    loop :: IM.IntMap Int -> Int -> Int
    loop dp i = if i > amount then fromMaybe (-1) (IM.lookup amount dp)
      else ...
```
We consider the “previous” value, defaulting to -1 if it doesn’t exist. We also consider the “current” value for index i, but default to maxBound if it doesn’t exist. This is because we only want to insert a new number if it’s smaller, and maxBound will always be larger:
```haskell
coinChange :: [Int] -> Int -> Int
coinChange coins amount = ...
  where
    coinLoop :: Int -> Int -> IM.IntMap Int -> IM.IntMap Int
    coinLoop i coin dp =
      let prev = fromMaybe (-1) (IM.lookup (i - coin) dp)
          current = fromMaybe maxBound (IM.lookup i dp)
      in ...
```
If the prior value doesn’t exist, or if the existing value is already smaller than the previous value plus 1, then we keep dp the same. Otherwise we insert the new value at this index:
```haskell
coinChange :: [Int] -> Int -> Int
coinChange coins amount = ...
  where
    coinLoop :: Int -> Int -> IM.IntMap Int -> IM.IntMap Int
    coinLoop i coin dp =
      let prev = fromMaybe (-1) (IM.lookup (i - coin) dp)
          current = fromMaybe maxBound (IM.lookup i dp)
      in if prev == (-1) || current < prev + 1 then dp
         else IM.insert i (prev + 1) dp
    loop :: IM.IntMap Int -> Int -> Int
    loop dp i = if i > amount then fromMaybe (-1) (IM.lookup amount dp)
      else ...
```
Now to complete our function, we just have to invoke these two loops. We invoke the primary loop with a base map assigning the value 0 to the key 0. The secondary loop relies on foldr over the coins, and we use its result in our recursive call:
```haskell
import Data.Maybe (fromMaybe)
import qualified Data.IntMap.Lazy as IM

coinChange :: [Int] -> Int -> Int
coinChange coins amount = loop (IM.singleton 0 0) 1
  where
    coinLoop :: Int -> Int -> IM.IntMap Int -> IM.IntMap Int
    coinLoop i coin dp =
      let prev = fromMaybe (-1) (IM.lookup (i - coin) dp)
          current = fromMaybe maxBound (IM.lookup i dp)
      in if prev == (-1) || current < prev + 1 then dp
         else IM.insert i (prev + 1) dp
    loop :: IM.IntMap Int -> Int -> Int
    loop dp i = if i > amount then fromMaybe (-1) (IM.lookup amount dp)
      else loop (foldr (coinLoop i) dp coins) (i + 1)
```
And now we’re done!
Conclusion
Our first two problems have been simple, 1-dimensional DP problems. But DP really shines as a technique when applied across multiple dimensions. In the next two weeks we’ll consider some of these multi-dimensional DP problems.
For more practice with DP and other algorithms, sign up for Solve.hs, our Haskell problem solving course! The course has hundreds of practice problems so you can hone your Haskell skills!
Dynamic Programming Primer
We’re about to start our final stretch of Haskell/Rust LeetCode comparisons (for now). In this group, we’ll do a quick study of some dynamic programming problems, which are a common cause of headache on programming interviews. We’ll do a couple single-dimension problems, and then show DP in multiple dimensions. Haskell has a couple interesting quirks to work out with dynamic programming, so we’ll try to understand that by comparison to Rust.
Dynamic programming is one of a few different algorithms you’ll learn about in Module 3 of Solve.hs, our Haskell problem solving course. Check it out today!
The Problem
Today’s problem is called House Robber. Normally we wouldn’t want to encourage crime, but when people have such a convoluted security setup as this problem suggests, perhaps they can’t complain.
The idea is that we receive a list of integers, representing the value that we can gain from “robbing” each house on a street. The security system is set up so that the police will be alerted if and only if two adjacent houses are robbed. So we could rob every other house, and no police will come.
We are trying to determine the maximum value we can get from robbing these houses without setting off the alarm (by robbing adjacent houses).
Dynamic Programming Introduction
As mentioned in the introduction, we’ll solve this problem using dynamic programming. This term can mean a couple different things, but by and large the idea is that we use answers on smaller portions of the input (going down as far as base cases) to build up to the final answer on the full input.
This can be done from the top down, generally by means of recursion. If we do this, we’ll often want to “cache” answers to certain sub-problems (memoization) so we don’t do redundant calculations.
We can also build answers from the bottom up (tabulation). This often takes the form of creating an array and storing answers to smaller queries in this array. Then we loop from the start of this array to the end, which should give us our answer. Our solutions in this series will largely rest on this tabulation idea. However, as we’ll see, we don’t always need an array of all prior answers to do this!
The key in dynamic programming is to define, whether for an array index or a recursive call, exactly what a “partial” solution means. This will help us use partial solutions to build our complete solution.
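As a tiny illustration of both styles (a classic Fibonacci sketch of my own, not part of today’s problem), the top-down version caches recursive answers while the bottom-up version fills a table in order:

```rust
// Top-down: recursion plus a cache (memoization).
fn fib_memo(n: usize, cache: &mut Vec<Option<u64>>) -> u64 {
    if n < 2 {
        return n as u64;
    }
    if let Some(v) = cache[n] {
        return v; // reuse a previously computed answer
    }
    let v = fib_memo(n - 1, cache) + fib_memo(n - 2, cache);
    cache[n] = Some(v);
    v
}

// Bottom-up: fill an array from the base cases forward (tabulation).
fn fib_tab(n: usize) -> u64 {
    let mut dp = vec![0u64; n + 2]; // +2 so indices 0 and 1 exist even for n = 0
    dp[1] = 1;
    for i in 2..=n {
        dp[i] = dp[i - 1] + dp[i - 2];
    }
    dp[n]
}
```

Both compute the same values (for example, the 10th Fibonacci number is 55); the difference is purely whether answers are demanded top-down or filled in bottom-up.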
The Algorithm
Now let’s figure out how we’ll use dynamic programming for our house robbing problem. The broad idea is that we could define two arrays, the “robbed” array and the “unrobbed” array. Each of these should be equal in size to the number of houses on the street. Let’s carefully define what each array means.
Index i of the “robbed” array should reflect the maximum value we can get from the houses [0..i] such that we have “robbed” house i. Then the “unrobbed” array, at index i, contains the maximum total value we can get from the houses [0..i] such that we have not robbed house i.
When it comes to populating these arrays we need to think first about the base cases. Then we need to consider how to build a new case from existing cases we have. With a recursive solution we have the same pattern: base case and recursive case.
The first two indices for each can be trivially calculated; they are our base cases:
```
robbed[0] = input[0]   // Rob house 0
robbed[1] = input[1]   // Rob house 1
unrobbed[0] = 0        // Can’t rob any houses
unrobbed[1] = input[0] // Rob house 0
```
Now we need to build a generic case i, assuming that we have already calculated all the values from 0 to i - 1. To calculate robbed[i], we assume we are robbing house i, thus we add input[i]. If we are robbing house i we must not have robbed house i - 1, so we add this value to unrobbed[i - 1].
To calculate unrobbed[i], we have the option of whether or not we robbed house i - 1. It may be advantageous to skip two houses in a row! Consider an example like [100, 1, 1, 100]. So we take the maximum of unrobbed[i - 1] and robbed[i - 1].
This gives us our general case, and so at the end we simply select the maximum of robbed[n - 1] and unrobbed[n - 1].
We’ve been speaking in terms of arrays, but we can observe that we only need the i - 1 value from each array to construct the i values. This means we don’t actually have to store a complete array, which would take O(n) memory. Instead we can store the last “robbed” number and the last “unrobbed” number. This makes our solution O(1) memory.
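Here’s a sketch (my own cross-check, assuming at least two houses; the real solutions below handle the size-1 case) showing that the full-array formulation and the two-variable rolling version agree, including on the [100, 1, 1, 100] example above:

```rust
// Full-array tabulation, exactly as described in the text.
fn rob_arrays(input: &[i32]) -> i32 {
    let n = input.len();
    let mut robbed = vec![0; n];
    let mut unrobbed = vec![0; n];
    robbed[0] = input[0];
    robbed[1] = input[1];
    unrobbed[0] = 0;
    unrobbed[1] = input[0];
    for i in 2..n {
        robbed[i] = input[i] + unrobbed[i - 1];
        unrobbed[i] = std::cmp::max(unrobbed[i - 1], robbed[i - 1]);
    }
    std::cmp::max(robbed[n - 1], unrobbed[n - 1])
}

// O(1)-memory version: only the previous pair of values is kept.
fn rob_rolling(input: &[i32]) -> i32 {
    let (mut last_robbed, mut last_unrobbed) = (input[1], input[0]);
    for i in 2..input.len() {
        let new_robbed = input[i] + last_unrobbed;
        let new_unrobbed = std::cmp::max(last_unrobbed, last_robbed);
        last_robbed = new_robbed;
        last_unrobbed = new_unrobbed;
    }
    std::cmp::max(last_robbed, last_unrobbed)
}
```

Both return 200 for [100, 1, 1, 100]: rob houses 0 and 3, skipping two houses in a row.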
Haskell Solution
Now let’s write some code, starting with Haskell! LeetCode guarantees that our input is non-empty, but we still need to handle the size-1 case specially:
```haskell
robHouse :: V.Vector Int -> Int
robHouse nums = if n == 1 then nums V.! 0
  else ...
  where
    n = V.length nums
    ...
```
Now let’s write a recursive loop function that will take our prior two values (robbed and unrobbed) as well as the index. These are the “stateful” values of our loop. We’ll use these to either return the final value, or make a recursive call with new “robbed” and “unrobbed” values.
```haskell
robHouse :: V.Vector Int -> Int
robHouse nums = if n == 1 then nums V.! 0
  else ...
  where
    n = V.length nums
    loop :: (Int, Int) -> Int -> Int
    loop (lastRobbed, lastUnrobbed) i = ...
```
For the “final” case, we see if we have reached the end of our array (i = n), in which case we return the max of the two values:
```haskell
robHouse :: V.Vector Int -> Int
robHouse nums = if n == 1 then nums V.! 0
  else ...
  where
    n = V.length nums
    loop :: (Int, Int) -> Int -> Int
    loop (lastRobbed, lastUnrobbed) i = if i == n then max lastRobbed lastUnrobbed
      else ...
```
Now we fill in our recursive case, using the logic discussed in our algorithm:
```haskell
robHouse :: V.Vector Int -> Int
robHouse nums = if n == 1 then nums V.! 0
  else ...
  where
    n = V.length nums
    loop :: (Int, Int) -> Int -> Int
    loop (lastRobbed, lastUnrobbed) i = if i == n then max lastRobbed lastUnrobbed
      else
        let newRobbed = nums V.! i + lastUnrobbed
            newUnrobbed = max lastRobbed lastUnrobbed
        in loop (newRobbed, newUnrobbed) (i + 1)
```
Finally, we make the initial call to loop to get our answer! This completes our Haskell solution:
```haskell
import qualified Data.Vector as V

robHouse :: V.Vector Int -> Int
robHouse nums = if n == 1 then nums V.! 0
  else loop (nums V.! 1, nums V.! 0) 2
  where
    n = V.length nums
    loop :: (Int, Int) -> Int -> Int
    loop (lastRobbed, lastUnrobbed) i = if i == n then max lastRobbed lastUnrobbed
      else
        let newRobbed = nums V.! i + lastUnrobbed
            newUnrobbed = max lastRobbed lastUnrobbed
        in loop (newRobbed, newUnrobbed) (i + 1)
```
Even when tabulating from the ground up in Haskell, we can still use recursion!
Rust Solution
Our Rust solution is similar, just using a loop instead of a recursive function. We start by handling our edge case and coming up with the initial values for “last robbed” and “last unrobbed”.
```rust
pub fn rob(nums: Vec<i32>) -> i32 {
    let n = nums.len();
    if n == 1 {
        return nums[0];
    }
    let mut lastRobbed = nums[1];
    let mut lastUnrobbed = nums[0];
    ...
}
```
Now we just apply our algorithmic logic in a loop from 2 to n, resetting lastRobbed and lastUnrobbed each time.
```rust
pub fn rob(nums: Vec<i32>) -> i32 {
    let n = nums.len();
    if n == 1 {
        return nums[0];
    }
    let mut lastRobbed = nums[1];
    let mut lastUnrobbed = nums[0];
    for i in 2..n {
        let newRobbed = nums[i] + lastUnrobbed;
        let newUnrobbed = std::cmp::max(lastUnrobbed, lastRobbed);
        lastRobbed = newRobbed;
        lastUnrobbed = newUnrobbed;
    }
    return std::cmp::max(lastRobbed, lastUnrobbed);
}
```
And now we’re done with Rust!
Conclusion
Next week we’ll do a problem that actually requires us to store a full array of prior solutions. To learn the different stages in building up an understanding of dynamic programming, you should take our problem solving course, Solve.hs. Module 3 focuses on algorithms, including dynamic programming!
Apply the Trie: Word Search
Today will be a nice culmination of some of the work we’ve been doing with data structures and algorithms. In the past few weeks we’ve covered graph algorithms, particularly Depth First Search. And last week, we implemented the Trie data structure from scratch. Today we’ll solve a “Hard” problem (according to LeetCode) that pulls these pieces together!
For a comprehensive study of data structures and algorithms in Haskell, you should take a look at our course, Solve.hs. You’ll spend a full module on data structures in Haskell, and then another module learning about algorithms, especially graph algorithms!
The Problem
Today’s problem is Word Search II. In the first version of this problem, LeetCode asks you to determine if we can find a single word in a grid. In this second version, we receive an entire list of words, and we have to return the subset of those words that can be found in the grid.
Now the “search” mechanism on this grid is not simply a straight line search. We’ll use a limited Boggle Search. From each letter in a word, we can make any movement to the next letter, as long as it is horizontal or vertical (Boggle also allows diagonal, but we won’t).
So here’s an example grid:
```
CAND
XBIY
TENQ
```
Words like “CAN” and “TEN” are obviously allowed. But we can also use “CAB”, even though it doesn’t form a single straight line. Even better, we can use “CABINET”, snaking through all 3 rows. However, a word like “DIN” is disallowed since it would require a diagonal move. Also, we cannot “re-use” letters in the same word. So “TENET” would not be allowed, as it requires backtracking over the E and T.
We need to find the most efficient way to determine which of our input words can be found in the grid.
The Algorithm
This problem combines two elements we’ve recently worked with. First, we will use DFS ideas to actually search through the grid from a particular starting location. Second, we will use a Trie to store all of our input words. This will help us determine if we can stop searching. Once we find a string that is not a prefix in our Trie, we can discontinue the search branch.
Here’s a run-down of the solution:
- Make a Trie from the Input Words
- Search from each starting location in the grid, trying to add good words to a growing set of results.
- At each location, add the character to our string. See if the resulting string is still a valid prefix of our Trie.
- If it is, add the word to our results if the Trie indicates it is a valid word. Then search the neighboring locations, while keeping track of locations we’ve visited.
- Once we no longer have a valid Trie, we can stop this line of searching.
This summary obscures quite a few details, but after our last few weeks of work, those details shouldn’t be too difficult.
Updating Trie Functions
For both our solutions, we’ll assume we’re using the same Trie structure we built last week. We’ll need the general structure, as well as the insert function.
However, we won’t actually need the search or startsWith functions. Each time we call these, we traverse the full length of the string we’re querying. And the way our algorithm will work here, that will get quite inefficient (quadratic time overall).
Instead we’re going to rely on directly accessing sub-Tries, so that as we build our search word longer, we’ll get a Trie that assumes we’ve already searched that word in the main Trie. This will make subsequent searches faster.
To make this more convenient, we’ll just provide a function to get a Maybe Trie from the “sub-Tries” of the node we’re working with, based on the character. We’ll call this “popping” a Trie.
Here’s what it looks like in Rust:
```rust
impl Trie {
    ...
    fn pop(&self, c: char) -> Option<&Trie> {
        return self.nodes.get(&c);
    }
}
```
And in Haskell:
```haskell
popTrie :: Char -> Trie -> Maybe Trie
popTrie c (Trie _ subs) = M.lookup c subs
```
For completeness, here’s the entire Rust version of the Trie that we’ll need for this problem:
```rust
use std::collections::HashMap;
use std::str::Chars;

struct Trie {
    endsWord: bool,
    nodes: HashMap<char, Trie>
}

impl Trie {
    fn new() -> Self {
        Trie {
            endsWord: false,
            nodes: HashMap::new()
        }
    }

    fn insert(&mut self, word: String) {
        self.insertIt(word.chars());
    }

    fn insertIt(&mut self, mut iter: Chars) {
        if let Some(c) = iter.next() {
            if !self.nodes.contains_key(&c) {
                self.nodes.insert(c, Trie::new());
            }
            if let Some(subTrie) = self.nodes.get_mut(&c) {
                subTrie.insertIt(iter);
            }
        } else {
            self.endsWord = true;
        }
    }

    fn pop(&self, c: char) -> Option<&Trie> {
        return self.nodes.get(&c);
    }
}
```
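To see these pieces in action, here’s a condensed, self-contained sketch of the same structure (I’ve used snake_case names and an iterative insert for brevity; the behavior matches the recursive version above):

```rust
use std::collections::HashMap;

struct Trie {
    ends_word: bool,
    nodes: HashMap<char, Trie>,
}

impl Trie {
    fn new() -> Self {
        Trie { ends_word: false, nodes: HashMap::new() }
    }

    // Walk the characters, creating nodes as needed; mark the last one.
    fn insert(&mut self, word: &str) {
        let mut node = self;
        for c in word.chars() {
            node = node.nodes.entry(c).or_insert_with(Trie::new);
        }
        node.ends_word = true;
    }

    fn pop(&self, c: char) -> Option<&Trie> {
        self.nodes.get(&c)
    }
}

// Chaining pop calls walks the Trie one character at a time.
fn demo() -> (bool, bool) {
    let mut t = Trie::new();
    t.insert("at");
    t.insert("ate");
    let at = t.pop('a').and_then(|n| n.pop('t'));
    let completes_at = at.map(|n| n.ends_word).unwrap_or(false);
    let has_x = t.pop('x').is_some();
    (completes_at, has_x)
}
```

Popping one character at a time is exactly how the grid search below will walk the Trie as it extends the current word.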
And here’s the full Haskell version:
```haskell
import Data.Maybe (fromMaybe)
import qualified Data.Map as M

data Trie = Trie Bool (M.Map Char Trie)

insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) =
  let sub = fromMaybe (Trie False M.empty) (M.lookup c subs)
      newSub = insertTrie cs sub
  in (Trie ends (M.insert c newSub subs))

popTrie :: Char -> Trie -> Maybe Trie
popTrie c (Trie _ subs) = M.lookup c subs
```
Rust Solution
Now let’s move on to our solution, starting with Rust. As always with a graph problem, we’ll benefit from having a neighbors function. This will be very similar to functions we’ve written in the past few weeks, so we won’t dwell on it. In this case though, we’ll incorporate the visited set directly into this function, and exclude neighbors we’ve already seen:
```rust
pub fn neighbors(
    nr: usize,
    nc: usize,
    visitedLocs: &HashSet<(usize, usize)>,
    loc: (usize, usize)) -> Vec<(usize, usize)> {
    let r = loc.0;
    let c = loc.1;
    let mut results = Vec::new();
    if r > 0 && !visitedLocs.contains(&(r - 1, c)) {
        results.push((r - 1, c));
    }
    if c > 0 && !visitedLocs.contains(&(r, c - 1)) {
        results.push((r, c - 1));
    }
    if r + 1 < nr && !visitedLocs.contains(&(r + 1, c)) {
        results.push((r + 1, c));
    }
    if c + 1 < nc && !visitedLocs.contains(&(r, c + 1)) {
        results.push((r, c + 1));
    }
    return results;
}
```
Now, thinking back to the Islands example, we want to write a search function similar to the visit function we had before. The job of the visit function was to populate the visited set with all reachable tiles from the start. Our search function will populate a set of “results” with every word reachable from a certain location.
It will require some immutable inputs, such as the board dimensions and the board itself, along with several stateful items: the current (sub-)Trie, the String we are building, and the current visited set. Here’s the signature we will use:
```rust
pub fn search(
    nr: usize,
    nc: usize,
    board: &Vec<Vec<char>>,
    trie: &Trie,
    loc: (usize, usize),
    visitedLocs: &mut HashSet<(usize, usize)>,
    currentStr: &mut String,
    seenWords: &mut HashSet<String>) {
    ...
}
```
This function has two tasks. First, assess the current location to see if the word we completed by arriving here should be added, or if it’s at least a prefix of a remaining word. If either is true, our second job is to find this location’s neighbors and recursively call them, continuing to grow our string and search for longer words.
Let’s write the code for the first part. Naturally, we need the character at this grid location. Then we need to query our Trie to “pop” the sub-trie associated with this character. If this sub-Trie doesn’t exist, we immediately return. Otherwise, we consider this location “visited” (add it to the set) and we push the new character onto the string. If our new sub-Trie “ends a word”, then we add this word to our results set!
```rust
pub fn search(
    nr: usize,
    nc: usize,
    board: &Vec<Vec<char>>,
    trie: &Trie,
    loc: (usize, usize),
    visitedLocs: &mut HashSet<(usize, usize)>,
    currentStr: &mut String,
    seenWords: &mut HashSet<String>) {
    let c = board[loc.0][loc.1];
    if let Some(subTrie) = trie.pop(c) {
        currentStr.push(c);
        visitedLocs.insert(loc);
        if subTrie.endsWord {
            seenWords.insert(currentStr.clone());
        }
        ...
    }
}
```
Now we get our neighbors and recursively search them, passing updated mutable values. But there’s one extra thing to include! After we are done searching our neighbors, we should “undo” our mutable changes to the visited set and the string.
We haven’t had to make this kind of “backtracking” change before. But we don’t want to permanently keep this location in this visited set, nor keep the string modified. When we return to our caller, we want these mutable values to be the same as how we got them. Otherwise, subsequent calls may be disturbed, and we’ll get incorrect answers!
```rust
pub fn search(
    nr: usize,
    nc: usize,
    board: &Vec<Vec<char>>,
    trie: &Trie,
    loc: (usize, usize),
    visitedLocs: &mut HashSet<(usize, usize)>,
    currentStr: &mut String,
    seenWords: &mut HashSet<String>) {
    let c = board[loc.0][loc.1];
    if let Some(subTrie) = trie.pop(c) {
        currentStr.push(c);
        visitedLocs.insert(loc);
        if subTrie.endsWord {
            seenWords.insert(currentStr.clone());
        }
        let ns = neighbors(nr, nc, visitedLocs, loc);
        for n in ns {
            search(nr, nc, board, subTrie, n, visitedLocs, currentStr, seenWords);
        }
        // Backtrack! Remove this location and pop this character
        visitedLocs.remove(&loc);
        currentStr.pop();
    }
}
```
This completes the search function. Now we just have to call it! We start our primary function by initializing our key values, especially our Trie. We need to insert all the starting words into a Trie that we create:
```rust
pub fn find_words(board: Vec<Vec<char>>, words: Vec<String>) -> Vec<String> {
    let mut trie = Trie::new();
    for word in &words {
        trie.insert(word.to_string());
    }
    let mut results = HashSet::new();
    let nr = board.len();
    let nc = board[0].len();
    ...
}
```
And now we just loop through each location in the grid and search it as a starting location! For an extra optimization, we can stop our search early if we have found all of our words.
```rust
pub fn find_words(board: Vec<Vec<char>>, words: Vec<String>) -> Vec<String> {
    let mut trie = Trie::new();
    for word in &words {
        trie.insert(word.to_string());
    }
    let mut results = HashSet::new();
    let nr = board.len();
    let nc = board[0].len();
    for i in 0..nr {
        for j in 0..nc {
            if results.len() < words.len() {
                let mut visited = HashSet::new();
                let mut curr = String::new();
                search(nr, nc, &board, &trie, (i, j), &mut visited, &mut curr, &mut results);
            }
        }
    }
    return results.into_iter().collect();
}
```
And we’re done! Here is our complete Rust solution!
```rust
pub fn neighbors(
    nr: usize,
    nc: usize,
    visitedLocs: &HashSet<(usize, usize)>,
    loc: (usize, usize)) -> Vec<(usize, usize)> {
    let r = loc.0;
    let c = loc.1;
    let mut results = Vec::new();
    if r > 0 && !visitedLocs.contains(&(r - 1, c)) {
        results.push((r - 1, c));
    }
    if c > 0 && !visitedLocs.contains(&(r, c - 1)) {
        results.push((r, c - 1));
    }
    if r + 1 < nr && !visitedLocs.contains(&(r + 1, c)) {
        results.push((r + 1, c));
    }
    if c + 1 < nc && !visitedLocs.contains(&(r, c + 1)) {
        results.push((r, c + 1));
    }
    return results;
}

pub fn search(
    nr: usize,
    nc: usize,
    board: &Vec<Vec<char>>,
    trie: &Trie,
    loc: (usize, usize),
    visitedLocs: &mut HashSet<(usize, usize)>,
    currentStr: &mut String,
    seenWords: &mut HashSet<String>) {
    let c = board[loc.0][loc.1];
    if let Some(subTrie) = trie.pop(c) {
        currentStr.push(c);
        visitedLocs.insert(loc);
        if subTrie.endsWord {
            seenWords.insert(currentStr.clone());
        }
        let ns = neighbors(nr, nc, visitedLocs, loc);
        for n in ns {
            search(nr, nc, board, subTrie, n, visitedLocs, currentStr, seenWords);
        }
        // Backtrack! Remove this location and pop this character
        visitedLocs.remove(&loc);
        currentStr.pop();
    }
}

pub fn find_words(board: Vec<Vec<char>>, words: Vec<String>) -> Vec<String> {
    let mut trie = Trie::new();
    for word in &words {
        trie.insert(word.to_string());
    }
    let mut results = HashSet::new();
    let nr = board.len();
    let nc = board[0].len();
    for i in 0..nr {
        for j in 0..nc {
            if results.len() < words.len() {
                let mut visited = HashSet::new();
                let mut curr = String::new();
                search(nr, nc, &board, &trie, (i, j), &mut visited, &mut curr, &mut results);
            }
        }
    }
    return results.into_iter().collect();
}
```
Haskell Solution
Our Haskell solution starts with some of the same beats. We’ll create our initial Trie through insertion and define a familiar neighbors function:
```haskell
findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = ...
  where
    ((minRow, minCol), (maxRow, maxCol)) = A.bounds board
    trie = foldr insertTrie (Trie False M.empty) allWords
    neighbors :: HS.HashSet (Int, Int) -> (Int, Int) -> [(Int, Int)]
    neighbors visited (r, c) =
      let up = if r > minRow && not (HS.member (r - 1, c) visited) then Just (r - 1, c) else Nothing
          left = if c > minCol && not (HS.member (r, c - 1) visited) then Just (r, c - 1) else Nothing
          down = if r < maxRow && not (HS.member (r + 1, c) visited) then Just (r + 1, c) else Nothing
          right = if c < maxCol && not (HS.member (r, c + 1) visited) then Just (r, c + 1) else Nothing
      in catMaybes [up, left, down, right]
    ...
```
Now let’s think about our search function. This function’s job is to update a set, given a particular location. We’re going to loop over this function with many locations. So we want the end of its signature to look like:
```haskell
(Int, Int) -> HS.HashSet String -> HS.HashSet String
```
This will allow us to use it with foldr. But we still want to think about the stateful elements going into this function: the Trie, the visited set, and the accumulated string. These also change from call to call, but they should come earlier in the signature, since we can fix them across each loop of the function. So here’s what our type signature looks like:
```haskell
search ::
  Trie ->
  HS.HashSet (Int, Int) ->
  String ->
  (Int, Int) ->
  HS.HashSet String ->
  HS.HashSet String
```
The first order of business in this function is to “pop” the Trie based on the character at this location and see if it exists. If not, we simply return our original set:
```haskell
findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = ...
  where
    ...
    search ::
      Trie ->
      HS.HashSet (Int, Int) ->
      String ->
      (Int, Int) ->
      HS.HashSet String ->
      HS.HashSet String
    search trie' visited currentStr loc seenWords = case popTrie (board A.! loc) trie' of
      Nothing -> seenWords
      Just sub@(Trie ends _) -> ...
```
Now we update our current string and visited set, while adding the word to our results if the sub-Trie indicates we are at the end of a word:
```haskell
findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = ...
  where
    ...
    search ::
      Trie ->
      HS.HashSet (Int, Int) ->
      String ->
      (Int, Int) ->
      HS.HashSet String ->
      HS.HashSet String
    search trie' visited currentStr loc seenWords = case popTrie (board A.! loc) trie' of
      Nothing -> seenWords
      Just sub@(Trie ends _) ->
        let currentStr' = board A.! loc : currentStr
            visited' = HS.insert loc visited
            seenWords' = if ends then HS.insert (reverse currentStr') seenWords else seenWords
            ...
```
And now we get our neighbors and loop through them with foldr. Observe how we define the function f that fixes the first three parameters with our new mutable values so we can cleanly call foldr.
```haskell
findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = ...
  where
    ...
    search ::
      Trie ->
      HS.HashSet (Int, Int) ->
      String ->
      (Int, Int) ->
      HS.HashSet String ->
      HS.HashSet String
    search trie' visited currentStr loc seenWords = case popTrie (board A.! loc) trie' of
      Nothing -> seenWords
      Just sub@(Trie ends _) ->
        let currentStr' = board A.! loc : currentStr
            visited' = HS.insert loc visited
            seenWords' = if ends then HS.insert (reverse currentStr') seenWords else seenWords
            ns = neighbors visited loc
            f = search sub visited' currentStr'
        in foldr f seenWords' ns
```
We’re almost done now! Having written the “inner” loop, we just have to write the “outer” loop that will loop through every location as a starting point.
```haskell
findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = HS.toList result
  where
    ...
    trie = foldr insertTrie (Trie False M.empty) allWords
    result = foldr (search trie HS.empty "") HS.empty (A.indices board)
```
Here is our complete Haskell solution!
```haskell
import Data.Maybe (catMaybes)
import qualified Data.Array as A
import qualified Data.HashSet as HS
import qualified Data.Map as M

findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = HS.toList result
  where
    ((minRow, minCol), (maxRow, maxCol)) = A.bounds board
    trie = foldr insertTrie (Trie False M.empty) allWords
    neighbors :: HS.HashSet (Int, Int) -> (Int, Int) -> [(Int, Int)]
    neighbors visited (r, c) =
      let up = if r > minRow && not (HS.member (r - 1, c) visited) then Just (r - 1, c) else Nothing
          left = if c > minCol && not (HS.member (r, c - 1) visited) then Just (r, c - 1) else Nothing
          down = if r < maxRow && not (HS.member (r + 1, c) visited) then Just (r + 1, c) else Nothing
          right = if c < maxCol && not (HS.member (r, c + 1) visited) then Just (r, c + 1) else Nothing
      in catMaybes [up, left, down, right]
    search ::
      Trie ->
      HS.HashSet (Int, Int) ->
      String ->
      (Int, Int) ->
      HS.HashSet String ->
      HS.HashSet String
    search trie' visited currentStr loc seenWords = case popTrie (board A.! loc) trie' of
      Nothing -> seenWords
      Just sub@(Trie ends _) ->
        let currentStr' = board A.! loc : currentStr
            visited' = HS.insert loc visited
            seenWords' = if ends then HS.insert (reverse currentStr') seenWords else seenWords
            ns = neighbors visited loc
            f = search sub visited' currentStr'
        in foldr f seenWords' ns
    result = foldr (search trie HS.empty "") HS.empty (A.indices board)
```
Conclusion
This problem brought together a lot of interesting solution components. We applied our Trie implementation from last week, and used several recurring ideas from graph search problems. Next week we’re going to switch gears a bit and start discussing dynamic programming.
To learn more about all of these problem concepts, you need to take a look at Solve.hs. It gives a fairly comprehensive look at problem solving concepts in Haskell. If you want to understand how to shape your functions to work with folds like we did in this article, you’ll learn about that in Module 1. If you want to implement and apply data structures like graphs and tries, Module 2 will teach you. And if you want practice writing and using key graph algorithms in Haskell, Module 3 will give you the experience you need!
Writing Our Own Structure: Tries in Haskell & Rust
In the last few weeks we’ve studied a few different graph problems. Graphs are interesting because they are a derived structure that we can represent in different ways to solve different problems. Today, we’ll solve a LeetCode problem that actually focuses on writing a data structure ourselves to satisfy certain requirements! Next week, we’ll use this structure to solve a problem.
If you want to improve your Haskell data structure skills, both with built-in types and in making your own types, you should check out Solve.hs, our problem solving course. Module 2 is heavily focused on data structures, so it will clear up a lot of blind spots you might have working with these in Haskell!
The Problem
Unlike our previous problems, we’re not trying to solve some peculiar question formulation with inputs and outputs. Our task today is to implement the basic functions for a Trie data structure. A Trie (pronounced “try”) is also known as a Prefix tree. They are most often used in the context of strings (though other stream-like types are also possible). We’ll make one that is effectively a container of strings that efficiently supports 3 operations:
- Insert - Add a new word into our set
- Search - Determine if we have previously inserted the given word into our tree
- Starts With - Determine if we have inserted any word that has the given input as a prefix
The first two operations are typical of tree sets and hash sets, but the third operation is distinctive for a Trie.
We’ll implement these three functions, as well as provide a means of constructing an empty Trie. We’ll work with the constraint that all our input strings consist only of lowercase English letters.
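To pin down exactly what these three operations should do before we build the real structure, here's a naive reference version. This NaiveTrie is a hypothetical stand-in (not the data structure we'll actually build) that just scans a flat list of words; the real Trie will give the same answers far more efficiently:

```rust
// A naive "reference" for the three operations, using a plain vector of
// words. It pins down the expected behavior, but every query scans the
// whole word list, which is exactly the inefficiency a Trie avoids.
struct NaiveTrie {
    words: Vec<String>,
}

impl NaiveTrie {
    fn new() -> Self {
        NaiveTrie { words: Vec::new() }
    }

    // Insert: add a new word into our set
    fn insert(&mut self, word: String) {
        self.words.push(word);
    }

    // Search: was this exact word inserted?
    fn search(&self, word: String) -> bool {
        self.words.iter().any(|w| *w == word)
    }

    // Starts With: was any inserted word prefixed by this input?
    fn starts_with(&self, prefix: String) -> bool {
        self.words.iter().any(|w| w.starts_with(&prefix))
    }
}

fn main() {
    let mut t = NaiveTrie::new();
    t.insert("apple".to_string());
    assert!(t.search("apple".to_string()));    // exact word: true
    assert!(!t.search("app".to_string()));     // prefix only: false
    assert!(t.starts_with("app".to_string())); // prefix of "apple": true
    println!("ok");
}
```

Note the distinction the third assertion highlights: "app" is not in the set, but it is a prefix of something that is.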
The Algorithm
If at first you pronounced “Trie” like “Tree”, you’re not really wrong. Our core implementation strategy will be to create a recursive tree structure. It’s easiest if we start by visualizing a trie. Here is a tree structure corresponding to a Trie containing the words “at”, “ate”, “an”, “bank” and “band”.
        _
       / \
      a   b
     / \   \
    t*  n*  a
    |        \
    e*        n
             / \
            k*  d*
The top node is a blank space _ representing the root node. All other nodes in our tree correspond to letters. When we trace a path from the root to any node, we get a valid prefix of a word in our Trie. A star (*) indicates that a node corresponds to a complete word. Note that interior nodes can be complete, as is the case with at.
This suggests a structure for each node in the Trie. A node should store a boolean value telling us if the node completes a word. Other than that, all it needs is a map keying from characters to other nodes further down the tree.
We can use this structure to write all our function implementations in relatively simple recursive terms.
Haskell
Recursive data structures and functions are very natural in Haskell, so we’ll start with that implementation. We’ll make our data type and provide it with the two fields…the boolean indicating the end of a word, and the map of characters to additional Trie nodes.
data Trie = Trie Bool (M.Map Char Trie)
Note that even though a node is visually represented by a particular character, we don’t actually need to store the character on the node. The fact that we arrive at a particular node by keying on a character from its parent is enough.
Now let’s write our implementations, starting with insert. In Haskell, we don’t write functions “on” the type because we can’t mutate expressions. We write functions that take one instance of the type and return another. So our insertTrie function has this signature:
insertTrie :: String -> Trie -> Trie
We want to build a recursive structure, and we have an input (String) that breaks down recursively. This means we have two cases to deal with in this function. Either the string is empty, or it is not:
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie ends subs) = ...
insertTrie (c : cs) (Trie ends subs) = ...
If the string is empty, we don't need to do much. We return a node with the same sub-tries, but with its Bool field set to True! This also highlights an edge case: we can insert the empty string into our Trie, simply by marking the root node as True!
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) = ...
In the recursive case, we’ll be making a recursive call on a particular “sub” Trie. We want to “lookup” if a Trie for the character c already exists. If not, we’ll make a default one (with False and empty sub-tries map).
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) =
  let sub = fromMaybe (Trie False M.empty) (M.lookup c subs)
      ...
Now we recursively insert the rest of the input into this sub-Trie:
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) =
  let sub = fromMaybe (Trie False M.empty) (M.lookup c subs)
      newSub = insertTrie cs sub
      ...
Finally, we map this newSub into the original Trie, using c as the key for it:
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) =
  let sub = fromMaybe (Trie False M.empty) (M.lookup c subs)
      newSub = insertTrie cs sub
  in (Trie ends (M.insert c newSub subs))
The search and startsWith functions follow a similar pattern, matching on the input string. With search, an empty string means we just look at the Bool field of our Trie. If it's True, then the word we are searching for was inserted into our Trie:
searchTrie :: String -> Trie -> Bool
searchTrie [] (Trie ends _) = ends
searchTrie (c : cs) (Trie _ subs) = ...
If not, we’ll check for the sub-Trie using the character c. If it doesn’t exist, then the word we’re looking for isn’t in our Trie. If it does, we recursively search for the “rest” of the string in that Trie:
searchTrie :: String -> Trie -> Bool
searchTrie [] (Trie ends _) = ends
searchTrie (c : cs) (Trie _ subs) = case M.lookup c subs of
  Nothing -> False
  Just sub -> searchTrie cs sub
Finally, startsWith is almost identical to search. The only difference is that if we reach the end of the input word, we always return True, as it only needs to be a prefix:
startsWithTrie :: String -> Trie -> Bool
startsWithTrie [] _ = True
startsWithTrie (c : cs) (Trie _ subs) = case M.lookup c subs of
  Nothing -> False
  Just sub -> startsWithTrie cs sub
There is one interesting case here. We always consider the empty string to be a valid prefix, even if we haven't inserted anything into our Trie. This may seem counterintuitive, but LeetCode accepts this logic, just as it does with our Rust solution. You could add a workaround to change this behavior, but it results in less clean code.
Here’s the full Haskell solution:
import qualified Data.Map as M
import Data.Maybe (fromMaybe)

data Trie = Trie Bool (M.Map Char Trie)

insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) =
  let sub = fromMaybe (Trie False M.empty) (M.lookup c subs)
      newSub = insertTrie cs sub
  in (Trie ends (M.insert c newSub subs))

searchTrie :: String -> Trie -> Bool
searchTrie [] (Trie ends _) = ends
searchTrie (c : cs) (Trie _ subs) = case M.lookup c subs of
  Nothing -> False
  Just sub -> searchTrie cs sub

startsWithTrie :: String -> Trie -> Bool
startsWithTrie [] _ = True
startsWithTrie (c : cs) (Trie _ subs) = case M.lookup c subs of
  Nothing -> False
  Just sub -> startsWithTrie cs sub
Rust
The Rust solution follows the same algorithmic ideas, but the code looks quite a bit different. Rust allows mutable data structures, so it looks a bit more like your typical object oriented language in making structures. But there are some interesting quirks! Here’s the frame LeetCode gives you to work with:
struct Trie {
    ...
}

impl Trie {
    fn new() -> Self {
        ...
    }

    fn insert(&mut self, word: String) {
        ...
    }

    fn search(&self, word: String) -> bool {
        ...
    }

    fn starts_with(&self, prefix: String) -> bool {
        ...
    }
}
With Rust we define the fields of the struct, and then create an impl for the type with all the relevant functions. Each “class” method takes a self parameter, somewhat like Python. This can be mutable or not. A Rust “constructor” is typically done with a new function, as you see.
Despite Rust’s trickiness with ownership and borrowing, there are no obstacles to making a recursive data type in the same manner as our Haskell implementation. Here’s our struct definition, as well as the constructor:
use std::collections::HashMap;

struct Trie {
    endsWord: bool,
    nodes: HashMap<char, Trie>
}

impl Trie {
    fn new() -> Self {
        Trie {
            endsWord: false,
            nodes: HashMap::new()
        }
    }
    ...
}
When it comes to the main functions though, we don't want to make them directly recursive. Each takes a String input, and constructing a new String just to pop the first character isn't efficient. Rust is also peculiar in that you cannot index into strings by character, because UTF-8 encoding makes characters variable-width. So our solution will be to use the Chars iterator to efficiently "pop" characters while still being able to examine the characters that come next.
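To see what this buys us, here's a small standalone sketch of the Chars iterator in action. The count_leading helper is purely illustrative (it's not part of the Trie solution), but it shows the same pattern: .next() pops one character, and the iterator itself can be handed off to a recursive call without copying the string:

```rust
// Illustration of the Chars iterator approach: .next() "pops" one
// character at a time without ever allocating a new String, and the
// iterator can be passed to a recursive call cheaply.
fn count_leading(mut iter: std::str::Chars, target: char) -> usize {
    match iter.next() {
        Some(c) if c == target => 1 + count_leading(iter, target),
        _ => 0,
    }
}

fn main() {
    let word = String::from("aaab");
    let mut iter = word.chars();
    assert_eq!(iter.next(), Some('a')); // pops the first character
    assert_eq!(iter.as_str(), "aab");   // the remainder, with no new String built
    assert_eq!(count_leading(word.chars(), 'a'), 3);
    println!("ok");
}
```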
So let's start by making "It" versions of all our functions (insertIt, searchIt, startsWithIt) that take Chars iterators. We'll be able to call these functions recursively, and we invoke each one from the base function by calling .chars() on the input string.
use std::collections::HashMap;
use std::str::Chars;

struct Trie {
    endsWord: bool,
    nodes: HashMap<char, Trie>
}

impl Trie {
    fn new() -> Self {
        Trie {
            endsWord: false,
            nodes: HashMap::new()
        }
    }

    fn insert(&mut self, word: String) {
        self.insertIt(word.chars());
    }

    fn insertIt(&mut self, mut iter: Chars) {
        ...
    }

    fn search(&self, word: String) -> bool {
        return self.searchIt(word.chars());
    }

    fn searchIt(&self, mut iter: Chars) -> bool {
        ...
    }

    fn starts_with(&self, prefix: String) -> bool {
        return self.startsWithIt(prefix.chars());
    }

    fn startsWithIt(&self, mut iter: Chars) -> bool {
        ...
    }
}
Now let's zero in on these implementations, starting with insertIt. First we pattern match on the iterator. If it gives us None, we just set self.endsWord = true and we're done.
impl Trie {
    ...
    fn insertIt(&mut self, mut iter: Chars) {
        if let Some(c) = iter.next() {
            ...
        } else {
            self.endsWord = true;
        }
    }
}
Now, just like in Haskell, if this node doesn’t have a sub-Trie for the character c yet, we insert a new Trie for c. Then we recursively call insertIt on this “subTrie”.
impl Trie {
    fn insertIt(&mut self, mut iter: Chars) {
        if let Some(c) = iter.next() {
            if !self.nodes.contains_key(&c) {
                self.nodes.insert(c, Trie::new());
            }
            if let Some(subTrie) = self.nodes.get_mut(&c) {
                subTrie.insertIt(iter);
            }
        } else {
            self.endsWord = true;
        }
    }
}
That's it for insertion. Now for searching, we'll follow the same pattern matching protocol. If the Chars iterator is empty, we just check whether this node has endsWord set or not:
impl Trie {
    fn searchIt(&self, mut iter: Chars) -> bool {
        if let Some(c) = iter.next() {
            ...
        } else {
            return self.endsWord;
        }
    }
}
We check again for a sub-Trie under the key c. If it doesn’t exist, we return false. If it does, we just make a recursive call!
impl Trie {
    fn searchIt(&self, mut iter: Chars) -> bool {
        if let Some(c) = iter.next() {
            if let Some(subTrie) = self.nodes.get(&c) {
                return subTrie.searchIt(iter);
            } else {
                return false;
            }
        } else {
            return self.endsWord;
        }
    }
}
And with startsWith, it’s the same pattern. We do exactly the same thing as search, except that the else case is unambiguously true.
impl Trie {
    fn startsWithIt(&self, mut iter: Chars) -> bool {
        if let Some(c) = iter.next() {
            if let Some(subTrie) = self.nodes.get(&c) {
                return subTrie.startsWithIt(iter);
            } else {
                return false;
            }
        } else {
            return true;
        }
    }
}
Here’s our final Rust implementation:
use std::collections::HashMap;
use std::str::Chars;

struct Trie {
    endsWord: bool,
    nodes: HashMap<char, Trie>
}

impl Trie {
    fn new() -> Self {
        Trie {
            endsWord: false,
            nodes: HashMap::new()
        }
    }

    fn insert(&mut self, word: String) {
        self.insertIt(word.chars());
    }

    fn insertIt(&mut self, mut iter: Chars) {
        if let Some(c) = iter.next() {
            if !self.nodes.contains_key(&c) {
                self.nodes.insert(c, Trie::new());
            }
            if let Some(subTrie) = self.nodes.get_mut(&c) {
                subTrie.insertIt(iter);
            }
        } else {
            self.endsWord = true;
        }
    }

    fn search(&self, word: String) -> bool {
        return self.searchIt(word.chars());
    }

    fn searchIt(&self, mut iter: Chars) -> bool {
        if let Some(c) = iter.next() {
            if let Some(subTrie) = self.nodes.get(&c) {
                return subTrie.searchIt(iter);
            } else {
                return false;
            }
        } else {
            return self.endsWord;
        }
    }

    fn starts_with(&self, prefix: String) -> bool {
        return self.startsWithIt(prefix.chars());
    }

    fn startsWithIt(&self, mut iter: Chars) -> bool {
        if let Some(c) = iter.next() {
            if let Some(subTrie) = self.nodes.get(&c) {
                return subTrie.startsWithIt(iter);
            } else {
                return false;
            }
        } else {
            return true;
        }
    }
}
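As a quick sanity check, here's the classic example sequence for this problem run against the structure. Note this snippet is a condensed variant of the code above (search and starts_with share one traversal helper, and insertion uses the entry API) so that it stands alone:

```rust
#![allow(non_snake_case)]
// Condensed variant of the Trie above, just for a self-contained check.
use std::collections::HashMap;
use std::str::Chars;

struct Trie {
    endsWord: bool,
    nodes: HashMap<char, Trie>,
}

impl Trie {
    fn new() -> Self {
        Trie { endsWord: false, nodes: HashMap::new() }
    }

    fn insert(&mut self, word: String) {
        self.insertIt(word.chars());
    }

    fn insertIt(&mut self, mut iter: Chars) {
        if let Some(c) = iter.next() {
            // Create the child node if needed, then recurse into it.
            self.nodes.entry(c).or_insert_with(Trie::new).insertIt(iter);
        } else {
            self.endsWord = true;
        }
    }

    fn search(&self, word: String) -> bool {
        self.walk(word.chars(), false)
    }

    fn starts_with(&self, prefix: String) -> bool {
        self.walk(prefix.chars(), true)
    }

    // Shared traversal: prefixOk decides the answer when the input runs out.
    fn walk(&self, mut iter: Chars, prefixOk: bool) -> bool {
        match iter.next() {
            Some(c) => match self.nodes.get(&c) {
                Some(sub) => sub.walk(iter, prefixOk),
                None => false,
            },
            None => prefixOk || self.endsWord,
        }
    }
}

fn main() {
    let mut trie = Trie::new();
    trie.insert("apple".to_string());
    assert!(trie.search("apple".to_string()));    // full word: true
    assert!(!trie.search("app".to_string()));     // not inserted yet: false
    assert!(trie.starts_with("app".to_string())); // prefix of "apple": true
    trie.insert("app".to_string());
    assert!(trie.search("app".to_string()));      // true now
    println!("ok");
}
```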
Conclusion
It's always interesting to practice making recursive data structures in a new language. While Rust shares some things in common with Haskell, defining data structures in Rust still feels more like other object-oriented languages than like Haskell. Next week, we'll put our Trie implementation to use by solving a problem that requires this data structure!
To learn more about writing your own data structures in Haskell, take our course, Solve.hs! Module 2 focuses heavily on data structures, and you’ll learn how to make some derived data structures that will improve your programs!
Topological Sort: Managing Mutable Structures in Haskell
Welcome back to our Rust vs. Haskell comparison series, featuring some of the most common LeetCode questions. We’ve done a couple graph problems the last two weeks, involving DFS and BFS.
Today we’ll do a graph problem involving a slightly more complicated algorithm. We’ll also use a couple data structures we haven’t seen in this series yet, and we’ll see how tricky it can get to have multiple mutable structures in a Haskell algorithm.
To learn all the details of managing your data structures in Haskell, check out Solve.hs, our problem solving course. You’ll learn all the key APIs, important algorithms, and you’ll get a lot of practice with LeetCode style questions!
The Problem
Today’s problem is called Course Schedule. We are given a number of courses, and a list of prerequisites among those courses. For a prerequisite pair (A,B), we cannot take Course A until we have taken Course B. Our job is to determine, in a sense, if the prerequisite list is well-defined. We want to see whether or not the list would actually allow us to take all the courses.
As an example, suppose we had these inputs:
Number of Courses: 4
Prerequisites: [(2, 0), (1, 0), (3, 1), (3, 2)]
This is a well defined set of courses. In order to take courses 1 and 2, we must take course 0. Then in order to take course 3, we have to take courses 1 and 2. So if we have the ordering 0->1->2->3, we can take all the courses. So we would return True.
However, if we were to add (1,3) there, we would not be able to take all the courses. We could take courses 0 and 2, but then we would be stuck because 1 and 3 have a mutual dependency. So we would return False with this list.
We are guaranteed that the course indices in the prerequisites list are in the range [0, numCourses - 1]. We are also guaranteed that all prerequisites are unique.
The Algorithm
For our algorithm, we will imagine these courses as living in a directed graph. If course A is a prerequisite of course B, there should be a directed edge from A to B. This problem essentially boils down to determining whether this graph has a cycle.
There are many ways to approach this, including relying on DFS or BFS as we discussed in the past two weeks! However, to introduce a new idea, we’ll solve this problem using the idea of topological sorting.
We can think of nodes as having “in degrees”. The “in degree” of a node is the number of directed edges coming into it. We are particularly concerned with nodes that have an in degree of 0. These are courses with no prerequisites, which we can take immediately.
Each time we “take” a course, we can increment a count of the courses we’ve taken, and then we can “remove” that node from the graph by decrementing the in degrees of all nodes that it is pointing to. If any of these nodes have their in degrees drop to 0 as a result of this, we can then add them to a queue of “0 degree nodes”.
If, once the queue is exhausted, we’ve taken every course, then we have proven that we can satisfy all the requirements! If not, then there must be a cycle preventing some nodes from ever having in-degree 0.
Rust Solution
We’ll start with a Rust solution. We need to manage a few different structures in this problem. The first two will be vectors giving us information about each course. We want to know the current “in degree” as well as having a list of the courses “unlocked” by each course.
Each “prerequisite” pair gives the unlocked course first, and then the prerequisite course. We’ll call these “post” and “pre”, respectively. We increase the in-degree of “post” and add “post” to the list of courses unlocked by “pre”:
pub fn can_finish(num_courses: i32, prerequisites: Vec<Vec<i32>>) -> bool {
    // More convenient to use usize
    let n = num_courses as usize;
    let mut inDegrees = Vec::with_capacity(n);
    inDegrees.resize(n, 0);
    // Maps from "pre" course to "post" course
    let mut unlocks: Vec<Vec<usize>> = Vec::with_capacity(n);
    unlocks.resize(n, Vec::new());
    for req in prerequisites {
        let post = req[0] as usize;
        let pre = req[1] as usize;
        inDegrees[post] += 1;
        unlocks[pre].push(post);
    }
    ...
}
Now we need to make a queue of 0-degree nodes. This uses VecDeque from last time. We’ll go through the initial in-degrees list and add all the nodes that are already 0. Then we’ll set up our loop to pop the front element until empty:
pub fn can_finish(num_courses: i32, prerequisites: Vec<Vec<i32>>) -> bool {
    let n = num_courses as usize;
    ...
    // Make a queue of 0-degree courses
    let mut queue: VecDeque<usize> = VecDeque::new();
    for i in 0..n {
        if inDegrees[i] == 0 {
            queue.push_back(i);
        }
    }
    let mut numSatisfied = 0;
    while let Some(course) = queue.pop_front() {
        ...
    }
    return numSatisfied == num_courses;
}
All we have to do now is process the course at the front of the queue on each iteration. We always increment the number of courses satisfied, since de-queuing a course indicates we are taking it. Then we loop through its unlocks and decrement each of their in-degrees. If reducing an in-degree takes it to 0, we add that unlocked course to the back of the queue:
pub fn can_finish(num_courses: i32, prerequisites: Vec<Vec<i32>>) -> bool {
    let n = num_courses as usize;
    ...
    let mut numSatisfied = 0;
    while let Some(course) = queue.pop_front() {
        numSatisfied += 1;
        for post in &unlocks[course] {
            inDegrees[*post] -= 1;
            if inDegrees[*post] == 0 {
                queue.push_back(*post);
            }
        }
    }
    return numSatisfied == num_courses;
}
This completes our solution! Here is the full Rust implementation:
use std::collections::VecDeque;

pub fn can_finish(num_courses: i32, prerequisites: Vec<Vec<i32>>) -> bool {
    let n = num_courses as usize;
    // Make a vector with inDegree counts
    let mut inDegrees = Vec::with_capacity(n);
    inDegrees.resize(n, 0);
    // Make a vector of "unlocks"
    let mut unlocks: Vec<Vec<usize>> = Vec::with_capacity(n);
    unlocks.resize(n, Vec::new());
    for req in prerequisites {
        let post = req[0] as usize;
        let pre = req[1] as usize;
        inDegrees[post] += 1;
        unlocks[pre].push(post);
    }
    // Make a queue of 0-degree courses
    let mut queue: VecDeque<usize> = VecDeque::new();
    for i in 0..n {
        if inDegrees[i] == 0 {
            queue.push_back(i);
        }
    }
    let mut numSatisfied = 0;
    while let Some(course) = queue.pop_front() {
        numSatisfied += 1;
        for post in &unlocks[course] {
            inDegrees[*post] -= 1;
            if inDegrees[*post] == 0 {
                queue.push_back(*post);
            }
        }
    }
    return numSatisfied == num_courses;
}
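To check this against the examples from the problem statement, here's the same algorithm (condensed with vec! initializers and snake_case names, so the snippet stands alone) run on the well-defined prerequisite list and on the cyclic variant with (1, 3) added:

```rust
use std::collections::VecDeque;

// Condensed copy of the topological sort solution for a standalone check.
pub fn can_finish(num_courses: i32, prerequisites: Vec<Vec<i32>>) -> bool {
    let n = num_courses as usize;
    let mut in_degrees = vec![0; n];
    let mut unlocks: Vec<Vec<usize>> = vec![Vec::new(); n];
    for req in prerequisites {
        let post = req[0] as usize;
        let pre = req[1] as usize;
        in_degrees[post] += 1;
        unlocks[pre].push(post);
    }
    // Seed the queue with courses that have no prerequisites.
    let mut queue: VecDeque<usize> = VecDeque::new();
    for i in 0..n {
        if in_degrees[i] == 0 {
            queue.push_back(i);
        }
    }
    let mut num_satisfied = 0;
    while let Some(course) = queue.pop_front() {
        num_satisfied += 1;
        for post in &unlocks[course] {
            in_degrees[*post] -= 1;
            if in_degrees[*post] == 0 {
                queue.push_back(*post);
            }
        }
    }
    num_satisfied == num_courses
}

fn main() {
    // The well-defined example: ordering 0 -> 1 -> 2 -> 3 works.
    let ok = vec![vec![2, 0], vec![1, 0], vec![3, 1], vec![3, 2]];
    assert!(can_finish(4, ok));
    // Adding (1, 3) creates a mutual dependency between courses 1 and 3.
    let cyclic = vec![vec![2, 0], vec![1, 0], vec![3, 1], vec![3, 2], vec![1, 3]];
    assert!(!can_finish(4, cyclic));
    println!("ok");
}
```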
Haskell Solution
In Haskell, we can follow this same approach. However, this is a somewhat challenging algorithm for Haskell beginners, because there are a lot of data structure “modifications” occurring, and expressions in Haskell are immutable! So we’ll organize our solution into three different parts:
- Initializing our structures
- Writing loop modifiers
- Writing the loop
This solution will introduce two data structures we haven't used in this series so far: the IntMap and the Sequence (Seq), which we'll import qualified like so:
import qualified Data.IntMap.Lazy as IM
import qualified Data.Sequence as Seq
The IntMap type works more or less exactly like a normal Map, with the same API. However, it assumes we have Int as our key type, which makes certain operations more efficient than a generic ordered map.
Then Seq is the best thing to use for a FIFO queue. We would have used this last week if we implemented BFS from scratch.
We'll also make a few type aliases, since we'll be combining these structures and frequently using them in type signatures:
type DegCount = IM.IntMap Int
type CourseMaps = (DegCount, IM.IntMap [Int])
type CourseState = (Int, Seq.Seq Int, DegCount)
The setup to our problem is fairly simple. Our function takes the number of courses as an integer, and the prerequisites as a list of tuples. We’ll write a number of helper functions beneath this top level definition, but for additional clarity, we’ll show them independently as we write them.
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = ...
Initializing Our Structures
Recall that the first part of our Rust solution focused on populating 3 structures:
- The list of in-degrees (per node)
- The list of “unlocks” (per node)
- The initial queue of 0-degree nodes
We use IntMaps for the first two (and use the alias DegCount for the first). These are easier to modify than vectors in Haskell. The other noteworthy fact is that we want to create these together (this is why we have the CourseMaps alias combining them). We process each prerequisite pair, updating both of these maps. This means we want to write a folding function like so:
processPrereq :: (Int, Int) -> CourseMaps -> CourseMaps
For this function, we want to define two more helpers: one that increments the degree count for a key, and one that appends a new unlock in the other map.
incKey :: Int -> DegCount -> DegCount
appendUnlock :: Int -> Int -> IM.IntMap [Int] -> IM.IntMap [Int]
These two helpers are straightforward to implement. In each case, we check for the key existing. If it doesn’t exist, we insert the default value (either 1 or a singleton list). If it exists, we either increment the value for the degree, or we append the new unlocked course to the existing list.
incKey :: Int -> DegCount -> DegCount
incKey k mp = case IM.lookup k mp of
  Nothing -> IM.insert k 1 mp
  Just x -> IM.insert k (x + 1) mp

appendUnlock :: Int -> Int -> IM.IntMap [Int] -> IM.IntMap [Int]
appendUnlock pre post mp = case IM.lookup pre mp of
  Nothing -> IM.insert pre [post] mp
  Just prev -> IM.insert pre (post : prev) mp
Now it’s very tidy to implement our folding function, and apply it to get these initial values:
processPrereq :: (Int, Int) -> CourseMaps -> CourseMaps
processPrereq (post, pre) (inDegrees', unlocks') =
  (incKey post inDegrees', appendUnlock pre post unlocks')
Here’s where our function currently is then:
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = ...
  where
    (inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
Now we want to build our initial queue as well. For this, we just want to loop through the possible course numbers, and add any that are not in the map for inDegrees (we never insert something with a value of 0).
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = ...
  where
    (inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
    queue = Seq.fromList
      (filter (`IM.notMember` inDegrees) [0..numCourses-1])
Writing Loop Modifiers
Now we have to consider what structures are going to be part of our “loop” and how we’re going to modify them. The type alias CourseState already expresses our loop state. We want to track the number of courses satisfied so far, the queue of 0-degree nodes, and the remaining in-degree values.
The key modification is that we can reduce the in-degrees of remaining courses. When we do this, we want to know immediately if we reduced the in-degree to 0. So let’s write a function that decrements the value, except that it deletes the key entirely if it drops to 0. We’ll return a boolean indicating if the key no longer exists in the map after this process:
decKey :: Int -> DegCount -> (DegCount, Bool)
decKey key mp = case IM.lookup key mp of
  Nothing -> (mp, True)
  Just x -> if x <= 1
    then (IM.delete key mp, True)
    else (IM.insert key (x - 1) mp, False)
Now what’s the core function of the loop? When we “take” a course, we loop through its unlocks, reduce all their degrees, and track which ones are now 0. Since this is a loop that updates state (the remaining inDegrees), we want to write a folding function for it:
decDegree :: Int -> (DegCount, [Int]) -> (DegCount, [Int])
First we perform the decrement. Then if decKey returns True, we’ll add the course to our new0s list.
decDegree :: Int -> (DegCount, [Int]) -> (DegCount, [Int])
decDegree post (inDegrees', new0s) =
  let (inDegrees'', removed) = decKey post inDegrees'
  in (inDegrees'', if removed then (post : new0s) else new0s)
Writing the Loop
With all these helpers at our disposal, we can finally write our core loop. Recall the 3 parts of our loop state: the number of courses taken so far, the queue of 0-degree courses, and the in-degree values. This loop should just return the number of courses completed:
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = ...
  where
    (inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
    queue = Seq.fromList
      (filter (`IM.notMember` inDegrees) [0..numCourses-1])

    loop :: CourseState -> Int
    loop (numSatisfied, queue', inDegrees') = ...
If the queue is empty, we just return our accumulated number. While we're at it, the final action is to simply compare this loop result to the total number of courses to get our final result:
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = loop (0, queue, inDegrees) == numCourses
  where
    (inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
    queue = Seq.fromList
      (filter (`IM.notMember` inDegrees) [0..numCourses-1])

    loop :: CourseState -> Int
    loop (numSatisfied, queue', inDegrees') = case Seq.viewl queue' of
      Seq.EmptyL -> numSatisfied
      (course Seq.:< rest) -> ...
When we "pop" the first course off of the queue, we first get the list of "post" courses that could now be unlocked by this course. Then we can apply our decDegree helper to get the final inDegrees'' map and the "new 0s".
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = loop (0, queue, inDegrees) == numCourses
  where
    (inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
    queue = Seq.fromList
      (filter (`IM.notMember` inDegrees) [0..numCourses-1])

    loop :: CourseState -> Int
    loop (numSatisfied, queue', inDegrees') = case Seq.viewl queue' of
      Seq.EmptyL -> numSatisfied
      (course Seq.:< rest) ->
        let posts = fromMaybe [] (IM.lookup course unlocks)
            (inDegrees'', new0s) = foldr decDegree (inDegrees', []) posts
        ...
Finally, we append the new 0’s to the end of the queue, and we make our recursive call, completing the loop and the function!
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = loop (0, queue, inDegrees) == numCourses
  where
    (inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
    queue = Seq.fromList
      (filter (`IM.notMember` inDegrees) [0..numCourses-1])

    loop :: CourseState -> Int
    loop (numSatisfied, queue', inDegrees') = case Seq.viewl queue' of
      Seq.EmptyL -> numSatisfied
      (course Seq.:< rest) ->
        let posts = fromMaybe [] (IM.lookup course unlocks)
            (inDegrees'', new0s) = foldr decDegree (inDegrees', []) posts
            queue'' = foldl (Seq.|>) rest new0s
        in loop (numSatisfied + 1, queue'', inDegrees'')
Here’s the full solution, from start to finish:
import qualified Data.IntMap.Lazy as IM
import qualified Data.Sequence as Seq
import Data.Maybe (fromMaybe)

type DegCount = IM.IntMap Int
type CourseMaps = (DegCount, IM.IntMap [Int])
type CourseState = (Int, Seq.Seq Int, DegCount)

canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = loop (0, queue, inDegrees) == numCourses
  where
    incKey :: Int -> DegCount -> DegCount
    incKey k mp = case IM.lookup k mp of
      Nothing -> IM.insert k 1 mp
      Just x -> IM.insert k (x + 1) mp

    appendUnlock :: Int -> Int -> IM.IntMap [Int] -> IM.IntMap [Int]
    appendUnlock pre post mp = case IM.lookup pre mp of
      Nothing -> IM.insert pre [post] mp
      Just prev -> IM.insert pre (post : prev) mp

    processPrereq :: (Int, Int) -> CourseMaps -> CourseMaps
    processPrereq (post, pre) (inDegrees', unlocks') =
      (incKey post inDegrees', appendUnlock pre post unlocks')

    (inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs

    queue = Seq.fromList
      (filter (`IM.notMember` inDegrees) [0..numCourses-1])

    decKey :: Int -> DegCount -> (DegCount, Bool)
    decKey key mp = case IM.lookup key mp of
      Nothing -> (mp, True)
      Just x -> if x <= 1
        then (IM.delete key mp, True)
        else (IM.insert key (x - 1) mp, False)

    decDegree :: Int -> (DegCount, [Int]) -> (DegCount, [Int])
    decDegree post (inDegrees', new0s) =
      let (inDegrees'', removed) = decKey post inDegrees'
      in (inDegrees'', if removed then (post : new0s) else new0s)

    loop :: CourseState -> Int
    loop (numSatisfied, queue', inDegrees') = case Seq.viewl queue' of
      Seq.EmptyL -> numSatisfied
      (course Seq.:< rest) ->
        let posts = fromMaybe [] (IM.lookup course unlocks)
            (inDegrees'', new0s) = foldr decDegree (inDegrees', []) posts
            queue'' = foldl (Seq.|>) rest new0s
        in loop (numSatisfied + 1, queue'', inDegrees'')
Conclusion
This problem showed the challenge of working with multiple mutable types in Haskell loops. You have to be very diligent about tracking what pieces are mutable, and you often need to write a lot of helper functions to keep your code clean. In our course, Solve.hs, you’ll learn about writing compound data structures to help you solve problems more cleanly. A Graph is one example, and you’ll also learn about occurrence maps, which we could have used in this problem.
That’s all for graphs right now. In the next couple weeks, we’ll cover the Trie, a compound data structure that can help with some very specific problems.
Graph Algorithms in Board Games!
For last week's problem we started learning about graph algorithms, focusing on depth-first search. Today we'll do a problem from an old board game that will require us to use breadth-first search. We'll also learn about a special library in Haskell that lets us solve these types of problems without needing to implement all the details of these algorithms.
To learn more about this library and graph algorithms in Haskell, you should check out our problem solving course, Solve.hs! Module 3 of the course focuses on algorithms, with a special emphasis on graph algorithms!
The Problem
Today's problem comes from a kids' board game called Snakes and Ladders, which will take a little bit to explain. First, imagine we have a square board in an N x N grid, where each cell is numbered 1 to N^2. The bottom left corner is always "1", and numbers increase in a snake-like fashion. First they increase from left to right along the bottom row. Then they go from right to left in the next row, before reversing again. Here's what the numbers look like for a 6x6 board:
36 35 34 33 32 31
25 26 27 28 29 30
24 23 22 21 20 19
13 14 15 16 17 18
12 11 10 9 8 7
1 2 3 4 5 6
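This snake-like numbering can be captured in a small coordinate helper. The cell_to_coords function below is just an illustrative sketch (not part of the final solution), treating row 0 as the bottom row:

```rust
// Convert a 1-based cell number on an N x N Snakes and Ladders board to
// (row, col), where row 0 is the bottom row. Even-numbered rows (from the
// bottom) increase left to right; odd-numbered rows increase right to left.
fn cell_to_coords(cell: usize, n: usize) -> (usize, usize) {
    let row = (cell - 1) / n;
    let offset = (cell - 1) % n;
    let col = if row % 2 == 0 { offset } else { n - 1 - offset };
    (row, col)
}

fn main() {
    // On the 6x6 board above: cell 1 is bottom-left, cell 7 is the right
    // end of the second row, and cell 36 is the top-left goal.
    assert_eq!(cell_to_coords(1, 6), (0, 0));
    assert_eq!(cell_to_coords(7, 6), (1, 5));
    assert_eq!(cell_to_coords(36, 6), (5, 0));
    println!("ok");
}
```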
The “goal” is to reach the highest numbered tile, which is either in the top left (for even grid sizes) or the top right (for odd grid sizes). One moves by rolling a 6-sided die. Given the number on the die, you are entitled to move that many spaces. The ordinary path of movement is following the increasing numbers.
As is, the game is a little boring. You just always want to roll the highest number you can. However, various cells on the grid are equipped with "snakes" or "ladders", which move you around the grid if your die roll would cause your turn to end where one of these items starts. Ladders typically move you closer to the goal, while snakes typically move you away from it.
We can represent such a board by putting an integer on each cell. The integer -1 represents an ordinary cell, where you would simply proceed to the next cell in order. However, we represent the start of each snake or ladder with the number of the cell where you end up if your die roll lands you there. Here's an example:
-1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1
-1 35 -1 -1 13 -1
-1 -1 -1 -1 -1 -1
-1 15 -1 -1 -1 -1
This grid has two ladders. The first can take you from position 2 to position 15 (see the bottom left corner). The second can take you from position 14 to position 35. There is also a snake that will take you back from position 17 to position 13. Note that no matter the layout, you can only follow one snake or ladder on a turn. If you end your turn at the beginning of a ladder whose destination is the beginning of another ladder, you do not take the second ladder on that turn.
Our objective is to find the smallest number of dice rolls possible to reach the goal cell (which will always have -1). In this case, the answer is 4. Various combinations of 3 rolls can land us on 14, which will take us to 35. Then rolling 1 would take us to the goal.
It is possible to contrive a board where it is impossible to reach the goal! We need to handle these cases. In these situations we must return -1. Here is such a grid, with many snakes, all leading back to the start!
1 1 -1
1 1 1
-1 1 1
The Algorithm
This is a graph search problem where each step we take carries the same weight (one turn), and we are trying to find the shortest path. This makes it a canonical example of a Breadth-First Search (BFS) problem.
We solve BFS by maintaining a queue of search states. In our case, the search state might consist simply of our location, though we may also want to track the number of steps we needed to reach that location as part of the state.
We’ll have a single primary loop, where we remove the first element in our queue. We’ll find all its “neighbors” (the states reachable from that node), and place these on the end of the queue. Then we’ll continue processing.
BFS works out so that states with a lower “cost” (i.e. number of turns) will all be processed before any states with higher cost. This means that the first time we dequeue a goal state from our queue, we can be sure we have found the shortest path to that goal state.
As with last week’s problem, we’ll spend a fair amount of effort on our “neighbors” function, which is often the core of a graph solution. Once we have that in place, the mechanics of the graph search generally become quite easy.
Rust Solution
Once again we’ll start with Rust, because we’ll use a special trick in Haskell. As stated, we want to start with our neighbors function. We’ll represent a single location using the integer that labels it on the board, rather than its grid coordinates. So at its core, this function takes one usize and returns a vector of usize values. But we’ll also take the board (a 2D vector of integers) so we can follow the snakes and ladders. Finally, we’ll pass the size of the board (just n, since our board is always square) and the “goal” location so that we don’t have to recalculate these every time:
pub fn neighbors(n: usize, goal: usize, board: &Vec<Vec<i32>>, loc: usize) -> Vec<usize> {
let mut results = Vec::new();
...
return results;
}
The basic idea of this function is that we’ll loop through the possible die rolls (1 to 6) and return the resulting location from each roll. If we find that the roll would take us past the goal, then we can safely break:
pub fn neighbors(n: usize, goal: usize, board: &Vec<Vec<i32>>, loc: usize) -> Vec<usize> {
let mut results = Vec::new();
for i in 1..=6 {
if loc + i > goal {
break;
}
...
}
return results;
}
How do we actually get the resulting location? We need to use the board, but in order to use the board, we have to convert the location into 2D coordinates. So let’s just write the frame for a function converting a location into coordinates. We’ll fill it in later:
pub fn convert(n: usize, loc: usize) -> (usize, usize) {
...
}
Assuming we have this function, the rest of our neighbors logic is easy. We check the corresponding value for the location in board. If it is -1, we just use our prior location added to the die roll. Otherwise, we use the location given in the cell:
pub fn neighbors(n: usize, goal: usize, board: &Vec<Vec<i32>>, loc: usize) -> Vec<usize> {
let mut results = Vec::new();
for i in 1..=6 {
if loc + i > goal {
break;
}
let (row, col) = convert(n, loc + i);
let next = board[row][col];
if next == -1 {
results.push(loc + i);
} else {
results.push(next as usize);
}
}
return results;
}
So let’s fill in this conversion function. It’s tricky because of the snaking order of the board and because we start from the bottom (highest row index) and not the top. Nonetheless, we want to start by getting the quotient and remainder of our location with the side-length. (We subtract 1 since our locations are 1-indexed).
pub fn convert(n: usize, loc: usize) -> (usize, usize) {
let rowBase = (loc - 1) / n;
let colBase = (loc - 1) % n;
...
}
To get the final row, we simply take n - rowBase - 1. The column is trickier. We need to consider if the row base is even or odd. If it is even, the row is going from left to right. Otherwise, it goes from right to left. In the first case, the modulo for the column gives us the right column. In the second case, we need to subtract from n like we did with rows.
pub fn convert(n: usize, loc: usize) -> (usize, usize) {
let rowBase = (loc - 1) / n;
let colBase = (loc - 1) % n;
let row = n - rowBase - 1;
let col =
if rowBase % 2 == 0 {
colBase
} else {
n - colBase - 1
};
return (row, col);
}
But that’s all we need for conversion!
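As a quick sanity check, we can run this conversion against the 6x6 example board from earlier. Position 14 should land on the cell holding 35, and position 17 on the cell holding 13 (rows and columns are 0-indexed from the top left). This sketch restates the convert function so the block runs on its own:

```rust
// A copy of `convert` so this check is self-contained.
pub fn convert(n: usize, loc: usize) -> (usize, usize) {
    let row_base = (loc - 1) / n;
    let col_base = (loc - 1) % n;
    let row = n - row_base - 1;
    let col = if row_base % 2 == 0 { col_base } else { n - col_base - 1 };
    (row, col)
}

fn main() {
    assert_eq!(convert(6, 1), (5, 0));  // start: bottom left corner
    assert_eq!(convert(6, 14), (3, 1)); // the cell containing the ladder to 35
    assert_eq!(convert(6, 17), (3, 4)); // the cell containing the snake to 13
    assert_eq!(convert(6, 36), (0, 0)); // goal: top left for even n
    println!("conversion checks passed");
}
```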
Now that our neighbors function is wrapped up, we can finally write the core solution. For the Rust solution, we’ll define our “search state” as including the location and the number of steps we took to reach it, so a tuple (usize, usize). We’ll create a VecDeque of these, which is Rust’s structure for a queue, and insert our initial state (location 1, count 0):
use std::collections::VecDeque;
pub fn snakes_and_ladders(board: Vec<Vec<i32>>) -> i32 {
let n = board.len();
let goal = board.len() * board[0].len();
let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
queue.push_back((1,0));
...
}
We also want to track the locations we’ve already visited. This will be a hash set of the locations but not the counts. This is necessary to prevent infinite loops. Once we’ve visited a location there is no advantage to considering it again on a later branch (with this problem at least). We’ll also follow the practice of considering a cell “visited” once it is enqueued.
use std::collections::VecDeque;
use std::collections::HashSet;
pub fn snakes_and_ladders(board: Vec<Vec<i32>>) -> i32 {
let n = board.len();
let goal = board.len() * board[0].len();
let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
queue.push_back((1,0));
let mut visited = HashSet::new();
visited.insert(1);
...
}
Now we’ll run a loop popping the front of the queue and finding the “neighboring” locations. If our queue is empty, this indicates no path was possible, so we return -1.
use std::collections::VecDeque;
use std::collections::HashSet;
pub fn snakes_and_ladders(board: Vec<Vec<i32>>) -> i32 {
let n = board.len();
let goal = board.len() * board[0].len();
let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
queue.push_back((1,0));
let mut visited = HashSet::new();
visited.insert(1);
while let Some((idx, count)) = queue.pop_front() {
let ns = neighbors(n, goal, &board, idx);
...
}
return -1;
}
Now processing each neighbor is simple. First, if the neighbor is the goal, we’re done! Just return the dequeued count plus 1. Otherwise, check if we’ve visited the neighbor before. If not, push it to the back of the queue, along with an increased count:
pub fn snakes_and_ladders(board: Vec<Vec<i32>>) -> i32 {
let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
queue.push_back((1,0));
let n = board.len();
let goal = board.len() * board[0].len();
let mut visited = HashSet::new();
visited.insert(1);
while let Some((idx, count)) = queue.pop_front() {
let ns = neighbors(n, goal, &board, idx);
for next in ns {
if next == goal {
return (count + 1) as i32;
}
if !visited.contains(&next) {
queue.push_back((next, count + 1));
visited.insert(next);
}
}
}
return -1;
}
This completes our BFS solution! Here is the complete code:
use std::collections::VecDeque;
use std::collections::HashSet;
pub fn convert(n: usize, loc: usize) -> (usize, usize) {
let rowBase = (loc - 1) / n;
let colBase = (loc - 1) % n;
let row = n - rowBase - 1;
let col =
if rowBase % 2 == 0 {
colBase
} else {
n - colBase - 1
};
return (row, col);
}
pub fn neighbors(n: usize, goal: usize, board: &Vec<Vec<i32>>, loc: usize) -> Vec<usize> {
let mut results = Vec::new();
for i in 1..=6 {
if loc + i > goal {
break;
}
let (row, col) = convert(n, loc + i);
let next = board[row][col];
if next == -1 {
results.push(loc + i);
} else {
results.push(next as usize);
}
}
return results;
}
pub fn snakes_and_ladders(board: Vec<Vec<i32>>) -> i32 {
let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
queue.push_back((1,0));
let n = board.len();
let goal = board.len() * board[0].len();
let mut visited = HashSet::new();
visited.insert(1);
while let Some((idx, count)) = queue.pop_front() {
let ns = neighbors(n, goal, &board, idx);
for next in ns {
if next == goal {
return (count + 1) as i32;
}
if !visited.contains(&next) {
queue.push_back((next, count + 1));
visited.insert(next);
}
}
}
return -1;
}
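As a standalone check, we can run the whole solution (lightly restructured here so the block compiles on its own) against the example board from earlier, which should take 4 rolls:

```rust
use std::collections::{HashSet, VecDeque};

pub fn convert(n: usize, loc: usize) -> (usize, usize) {
    let row_base = (loc - 1) / n;
    let col_base = (loc - 1) % n;
    let row = n - row_base - 1;
    let col = if row_base % 2 == 0 { col_base } else { n - col_base - 1 };
    (row, col)
}

pub fn neighbors(n: usize, goal: usize, board: &[Vec<i32>], loc: usize) -> Vec<usize> {
    let mut results = Vec::new();
    for i in 1..=6 {
        if loc + i > goal {
            break;
        }
        let (row, col) = convert(n, loc + i);
        let next = board[row][col];
        if next == -1 {
            results.push(loc + i);
        } else {
            results.push(next as usize);
        }
    }
    results
}

pub fn snakes_and_ladders(board: Vec<Vec<i32>>) -> i32 {
    let n = board.len();
    let goal = n * n;
    let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
    queue.push_back((1, 0));
    let mut visited = HashSet::new();
    visited.insert(1);
    while let Some((idx, count)) = queue.pop_front() {
        for next in neighbors(n, goal, &board, idx) {
            if next == goal {
                return (count + 1) as i32;
            }
            // insert returns true only for locations we haven't seen yet
            if visited.insert(next) {
                queue.push_back((next, count + 1));
            }
        }
    }
    -1
}

fn main() {
    // The example board: ladders 2 -> 15 and 14 -> 35, snake 17 -> 13.
    let mut board = vec![vec![-1; 6]; 6];
    board[5][1] = 15;
    board[3][1] = 35;
    board[3][4] = 13;
    assert_eq!(snakes_and_ladders(board), 4);
    println!("reached the goal in 4 rolls");
}
```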
Haskell Solution
For our Haskell solution, we’re going to use a special shortcut. We’ll make use of the Algorithm.Search library to handle the mechanics of the BFS for us. The function we’ll use has this type signature (slightly simplified):
bfs :: (state -> [state]) -> (state -> Bool) -> state -> Maybe [state]
We provide 3 inputs. First is the “neighbors” function, taking one state and returning its neighbors. Second is the “goal” function, telling us if a state is our final goal state. Finally we give it the initial state. If a goal is reachable, we receive a path to that goal. If not, we receive Nothing. Since this library provides the full path for us automatically, we won’t track the number of steps in our state. Our “state” will simply be the location. So let’s begin by framing out our function:
snakesAndLadders :: A.Array (Int, Int) Int -> Int
snakesAndLadders board = ...
where
((minRow, minCol), (maxRow, _)) = A.bounds board
n = maxRow - minRow + 1
goal = n * n
convert :: Int -> (Int, Int)
neighbor :: Int -> Int
neighbors :: Int -> [Int]
Let’s start with convert. This follows the same rules we used in our Rust solution, so there’s not much to say here. We just have to make sure we account for non-zero start indices in Haskell arrays by adding minRow and minCol.
snakesAndLadders :: A.Array (Int, Int) Int -> Int
snakesAndLadders board = ...
where
((minRow, minCol), (maxRow, _)) = A.bounds board
n = maxRow - minRow + 1
goal = n * n
convert :: Int -> (Int, Int)
convert loc =
let (rowBase, colBase) = (loc - 1) `quotRem` n
row = minRow + (n - rowBase - 1)
col = minCol + if even rowBase then colBase else n - colBase - 1
in (row, col)
Now we’ll write a neighbor helper that converts a single location. This just makes our neighbors function a lot cleaner. We use the same logic of checking for -1 in the board, or else using the value we find there.
snakesAndLadders :: A.Array (Int, Int) Int -> Int
snakesAndLadders board = ...
where
((minRow, minCol), (maxRow, _)) = A.bounds board
n = maxRow - minRow + 1
goal = n * n
convert = ...
neighbor :: Int -> Int
neighbor loc =
let coord = convert loc
onBoard = board A.! coord
in if onBoard == -1 then loc else onBoard
Now we can write neighbors with a simple list comprehension. We look through each roll of 1-6, add it to the current location, filter if this location is past the goal, and then calculate the neighbor.
snakesAndLadders :: A.Array (Int, Int) Int -> Int
snakesAndLadders board = ...
where
((minRow, minCol), (maxRow, _)) = A.bounds board
n = maxRow - minRow + 1
goal = n * n
convert = ...
neighbor = ...
neighbors :: Int -> [Int]
neighbors loc =
[neighbor (loc + i) | i <- [1..6], loc + i <= goal]
Now for the coup de grâce. We call bfs with our neighbors function. The “goal” function is just (== goal), and the starting state is just 1. On success, it returns the shortest path, which does not include the initial state, so its length is exactly the number of rolls:
snakesAndLadders :: A.Array (Int, Int) Int -> Int
snakesAndLadders board = case bfs neighbors (== goal) 1 of
Nothing -> (-1)
Just path -> length path
where
((minRow, minCol), (maxRow, _)) = A.bounds board
n = maxRow - minRow + 1
goal = n * n
convert :: Int -> (Int, Int)
convert loc =
let (rowBase, colBase) = (loc - 1) `quotRem` n
row = minRow + (n - rowBase - 1)
col = minCol + if even rowBase then colBase else n - colBase - 1
in (row, col)
neighbor :: Int -> Int
neighbor loc =
let coord = convert loc
onBoard = board A.! coord
in if onBoard == -1 then loc else onBoard
neighbors :: Int -> [Int]
neighbors loc =
[neighbor (loc + i) | i <- [1..6], loc + i <= goal]
And that’s our complete Haskell solution!
Conclusion
If you take our Solve.hs course, Module 3 is your go-to for learning about graph algorithms! You’ll implement BFS from scratch in Haskell, and learn how to apply other helpers from Algorithm.Search. In next week’s article, we’ll do one more graph problem that goes beyond the basic ideas of DFS and BFS.
Starting out with Graph Algorithms: Basic DFS
For a few weeks now, we’ve been tackling problems related to data structures, with a sprinkling of algorithmic ideas in there. Last week, we covered sets and heaps. Prior to that, we considered Matrices and the binary search algorithm.
This week, we’ll cover our first graph problem! Graph problems often build on a lot of fundamental layers. You need to understand the algorithm itself. Then you need to use the right data structures to apply it. And you’ll also still need the core problem solving patterns at your disposal. These 3 areas correspond to the first 3 modules of Solve.hs, our Haskell problem solving course! Check out that course to level up your Haskell skills!
The Problem
Today’s problem is called Number of Islands. We are given a 2D array as input, where every cell is either “land” or “sea” (either represented as the characters 1 and 0, or True and False). We want to find the number of distinct islands in this grid. Two “land” cells are part of the same island if we can draw a path from one cell to the other that only uses other land cells and only travels up, down, left and right (but not diagonally).
Let’s suppose we have this example:
111000
100101
100111
110000
000110
This grid has 3 islands. The island in the top left corner comprises 7 connected cells. Then there’s a small island in the bottom right with only 2 cells. Finally, we have a third island in the middle right with 5 tiles. While it is diagonally adjacent to the first island, we do not count this as a connection.
The Algorithm
This is one of the most basic questions you’ll see that requires a graph search algorithm, like Depth-First-Search (DFS) or Breadth-First-Search (BFS). The basic principle is that we will select a starting coordinate for a search. We will use one of these algorithms to find all the land cells that are part of that cell’s island. We’ll then increment a counter for having found this island.
We need to track all the cells that are part of this island. We’ll then keep iterating for new start locations to find new islands, but we have to exclude any locations that have already been explored.
While BFS is certainly possible, the solutions we’ll write here will use DFS. Our solution will consist of 3 components:
- A “neighbors” function that finds all adjacent land tiles to a given tile.
- A “visit” function that will take a starting coordinate and populate a “visited” set with all of the cells on the same island as the starting coordinate.
- A core “loop” that will consider each coordinate as a possible starting value for an island.
This ordering represents more of a “bottom up” approach to solving the problem. Going “top down” also works, and may be easier if you’re unfamiliar with graph algorithms. But as you get more practice with them, you’ll get a feel for knowing the bottom layers you need right away.
Rust Solution
To write our solution, we’ll write the components in our bottom up order. We’ll start with our neighbors function. This will take the island grid (LeetCode supplies us with a Vec<Vec<char>>) and the current location. We’ll represent locations as tuples, (usize, usize). This function will return a vector of locations.
use std::collections::HashSet;
pub fn neighbors(
grid: &Vec<Vec<char>>,
loc: &(usize,usize)) -> Vec<(usize,usize)> {
...
}
We’ll start this function by defining a few values. We want to know the length and width of the grid, as well as defining r and c to quickly reference the current location.
use std::collections::HashSet;
pub fn neighbors(
grid: &Vec<Vec<char>>,
loc: &(usize,usize)) -> Vec<(usize,usize)> {
let m = grid.len();
let n = grid[0].len();
let r = loc.0;
let c = loc.1;
let mut result: Vec<(usize,usize)> = Vec::new();
...
}
Now we just have to look in each of the four directions. Each direction is included as long as it is in bounds and the tile there is land. We’ll do our “visited” checks elsewhere.
pub fn neighbors(
grid: &Vec<Vec<char>>,
loc: &(usize,usize)) -> Vec<(usize,usize)> {
let m = grid.len();
let n = grid[0].len();
let r = loc.0;
let c = loc.1;
let mut result: Vec<(usize,usize)> = Vec::new();
if (r > 0 && grid[r - 1][c] == '1') {
result.push((r - 1, c));
}
if (c > 0 && grid[r][c - 1] == '1') {
result.push((r, c - 1));
}
if (r + 1 < m && grid[r + 1][c] == '1') {
result.push((r + 1, c));
}
if (c + 1 < n && grid[r][c + 1] == '1') {
result.push((r, c + 1));
}
return result;
}
Now let’s write the visit function. Remember, this function’s purpose is to populate the visited set starting from a certain location. We’ll use a HashSet of tuples for the visited set.
pub fn visit(
grid: &Vec<Vec<char>>,
visited: &mut HashSet<(usize,usize)>,
loc: &(usize,usize)) {
...
}
First, we’ll check if this location is already visited and return if so. Otherwise we’ll insert it.
pub fn visit(
grid: &Vec<Vec<char>>,
visited: &mut HashSet<(usize,usize)>,
loc: &(usize,usize)) {
if (visited.contains(loc)) {
return;
}
visited.insert(*loc);
...
}
All we have to do now is find the neighbors of this location, and recursively “visit” each one of them!
pub fn visit(
grid: &Vec<Vec<char>>,
visited: &mut HashSet<(usize,usize)>,
loc: &(usize,usize)) {
if (visited.contains(loc)) {
return;
}
visited.insert(*loc);
let ns = neighbors(grid, loc);
for n in ns {
visit(grid, visited, &n);
}
}
We’re not quite done, as we have to loop through our grid to call this function on each possible start. This isn’t so bad though. We start our function by defining key terms.
pub fn num_islands(grid: Vec<Vec<char>>) -> i32 {
let m = grid.len();
let n = grid[0].len();
let mut visited: HashSet<(usize,usize)> = HashSet::new();
let mut islandCount = 0;
...
// islandCount will be our final result
return islandCount;
}
Now we’ll “loop” through each possible starting location:
pub fn num_islands(grid: Vec<Vec<char>>) -> i32 {
let m = grid.len();
let n = grid[0].len();
let mut visited: HashSet<(usize,usize)> = HashSet::new();
let mut islandCount = 0;
for row in 0..m {
for col in 0..n {
let loc: (usize,usize) = (row,col);
...
}
}
return islandCount;
}
The last question is, what do we do for each location? If the location is land AND it is still unvisited, we treat it as the start of a new island. This means we increase the island count and then “visit” the location. When we consider other cells on this island, they’re already visited, so we won’t increase the island count when we find them!
pub fn num_islands(grid: Vec<Vec<char>>) -> i32 {
let m = grid.len();
let n = grid[0].len();
let mut visited: HashSet<(usize,usize)> = HashSet::new();
let mut islandCount = 0;
for row in 0..m {
for col in 0..n {
let loc: (usize,usize) = (row,col);
if grid[row][col] == '1' && !(visited.contains(&loc)) {
islandCount += 1;
visit(&grid, &mut visited, &loc);
}
}
}
return islandCount;
}
Here’s our complete solution:
use std::collections::HashSet;
pub fn neighbors(
grid: &Vec<Vec<char>>,
loc: &(usize,usize)) -> Vec<(usize,usize)> {
let m = grid.len();
let n = grid[0].len();
let r = loc.0;
let c = loc.1;
let mut result: Vec<(usize,usize)> = Vec::new();
if (r > 0 && grid[r - 1][c] == '1') {
result.push((r - 1, c));
}
if (c > 0 && grid[r][c - 1] == '1') {
result.push((r, c - 1));
}
if (r + 1 < m && grid[r + 1][c] == '1') {
result.push((r + 1, c));
}
if (c + 1 < n && grid[r][c + 1] == '1') {
result.push((r, c + 1));
}
return result;
}
pub fn visit(
grid: &Vec<Vec<char>>,
visited: &mut HashSet<(usize,usize)>,
loc: &(usize,usize)) {
if (visited.contains(loc)) {
return;
}
visited.insert(*loc);
let ns = neighbors(grid, loc);
for n in ns {
visit(grid, visited, &n);
}
}
pub fn num_islands(grid: Vec<Vec<char>>) -> i32 {
let m = grid.len();
let n = grid[0].len();
let mut visited: HashSet<(usize,usize)> = HashSet::new();
let mut islandCount = 0;
for row in 0..m {
for col in 0..n {
let loc: (usize,usize) = (row,col);
if grid[row][col] == '1' && !(visited.contains(&loc)) {
islandCount += 1;
visit(&grid, &mut visited, &loc);
}
}
}
return islandCount;
}
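Running this on the example grid from the start of the article should report 3 islands. This standalone sketch repeats the functions above (lightly condensed) so it compiles on its own:

```rust
use std::collections::HashSet;

pub fn neighbors(grid: &Vec<Vec<char>>, loc: &(usize, usize)) -> Vec<(usize, usize)> {
    let m = grid.len();
    let n = grid[0].len();
    let (r, c) = *loc;
    let mut result = Vec::new();
    if r > 0 && grid[r - 1][c] == '1' { result.push((r - 1, c)); }
    if c > 0 && grid[r][c - 1] == '1' { result.push((r, c - 1)); }
    if r + 1 < m && grid[r + 1][c] == '1' { result.push((r + 1, c)); }
    if c + 1 < n && grid[r][c + 1] == '1' { result.push((r, c + 1)); }
    result
}

pub fn visit(grid: &Vec<Vec<char>>, visited: &mut HashSet<(usize, usize)>, loc: &(usize, usize)) {
    // insert returns false if the location was already in the set
    if !visited.insert(*loc) {
        return;
    }
    for n in neighbors(grid, loc) {
        visit(grid, visited, &n);
    }
}

pub fn num_islands(grid: Vec<Vec<char>>) -> i32 {
    let mut visited: HashSet<(usize, usize)> = HashSet::new();
    let mut island_count = 0;
    for row in 0..grid.len() {
        for col in 0..grid[0].len() {
            let loc = (row, col);
            if grid[row][col] == '1' && !visited.contains(&loc) {
                island_count += 1;
                visit(&grid, &mut visited, &loc);
            }
        }
    }
    island_count
}

fn main() {
    let grid: Vec<Vec<char>> = ["111000", "100101", "100111", "110000", "000110"]
        .iter()
        .map(|s| s.chars().collect())
        .collect();
    assert_eq!(num_islands(grid), 3);
    println!("found 3 islands");
}
```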
Haskell Solution
The structure of this solution translates well to Haskell, since it’s a recursive solution at its root. We’ll just use a couple of folds to handle the other loops. Let’s outline the solution:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
neighbors :: (Int, Int) -> [(Int, Int)]
visit :: (Int, Int) -> HS.HashSet (Int, Int) -> HS.HashSet (Int, Int)
loop :: (Int, Int) -> (Int, HS.HashSet (Int, Int)) -> (Int, HS.HashSet (Int, Int))
Since we’re writing our functions “in line”, we don’t need to pass the grid around like we did in our Rust solution (though inline functions are also possible there). What you should observe immediately is that visit and loop have a similar structure. They both fit into the a -> b -> b pattern we want for foldr! We’ll use this to great effect!
But first, let’s fill in neighbors. Each of the 4 directions requires the same two conditions we used before. We make sure it’s not out of bounds, and that the next tile is “land”. Here’s how we check the “up” direction:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
neighbors :: (Int, Int) -> [(Int, Int)]
neighbors (row, col) =
let up = if row > minRow && grid A.! (row - 1, col) then Just (row - 1, col) else Nothing
...
We return Nothing if it is not a valid neighbor. Then we just combine the four directional options with catMaybes to complete this helper:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
neighbors :: (Int, Int) -> [(Int, Int)]
neighbors (row, col) =
let up = if row > minRow && grid A.! (row - 1, col) then Just (row - 1, col) else Nothing
left = if col > minCol && grid A.! (row, col - 1) then Just (row, col - 1) else Nothing
down = if row < maxRow && grid A.! (row + 1, col) then Just (row + 1, col) else Nothing
right = if col < maxCol && grid A.! (row, col + 1) then Just (row, col + 1) else Nothing
in catMaybes [up, left, down, right]
Now we start the visit function by checking if we’ve already visited the location, and add it to the set if not:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
visit :: (Int, Int) -> HS.HashSet (Int, Int) -> HS.HashSet (Int, Int)
visit coord visited = if HS.member coord visited then visited else
let visited' = HS.insert coord visited
...
Now we have to get the neighbors and “loop” through the neighbors so that we keep the visited set updated. This is where we’ll apply our first fold. We’ll recursively fold over visit on each of the possible neighbors, which will give us the final visited set from this process. That’s all we need for this helper!
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
visit :: (Int, Int) -> HS.HashSet (Int, Int) -> HS.HashSet (Int, Int)
visit coord visited = if HS.member coord visited then visited else
let visited' = HS.insert coord visited
ns = neighbors coord
in foldr visit visited' ns
Now our loop function will consider only a single coordinate. We think of this as having two pieces of state. First, the number of accumulated islands (the Int). Second, we have the visited set. So we check if the coordinate is unvisited land. If so, we increase the count, and get our “new” visited set by calling visit on it. If not, we return the original inputs.
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
loop :: (Int, Int) -> (Int, HS.HashSet (Int, Int)) -> (Int, HS.HashSet (Int, Int))
loop coord (count, visited) = if grid A.! coord && not (HS.member coord visited)
then (count + 1, visit coord visited)
else (count, visited)
Now for the final flourish. Our loop also has the structure for foldr. So we’ll loop over all the indices of our array, which will give us the final number of islands and the visited set. Our final answer is just the fst of these:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = fst (foldr loop (0, HS.empty) (A.indices grid))
where
loop :: (Int, Int) -> (Int, HS.HashSet (Int, Int)) -> (Int, HS.HashSet (Int, Int))
Here’s our final solution:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = fst (foldr loop (0, HS.empty) (A.indices grid))
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
neighbors :: (Int, Int) -> [(Int, Int)]
neighbors (row, col) =
let up = if row > minRow && grid A.! (row - 1, col) then Just (row - 1, col) else Nothing
left = if col > minCol && grid A.! (row, col - 1) then Just (row, col - 1) else Nothing
down = if row < maxRow && grid A.! (row + 1, col) then Just (row + 1, col) else Nothing
right = if col < maxCol && grid A.! (row, col + 1) then Just (row, col + 1) else Nothing
in catMaybes [up, left, down, right]
visit :: (Int, Int) -> HS.HashSet (Int, Int) -> HS.HashSet (Int, Int)
visit coord visited = if HS.member coord visited then visited else
let visited' = HS.insert coord visited
ns = neighbors coord
in foldr visit visited' ns
loop :: (Int, Int) -> (Int, HS.HashSet (Int, Int)) -> (Int, HS.HashSet (Int, Int))
loop coord (count, visited) = if grid A.! coord && not (HS.member coord visited)
then (count + 1, visit coord visited)
else (count, visited)
A Note on the Graph Algorithm
It seems like we solved this without even really applying a “graph” algorithm! We just did a loop and a recursive call and everything worked out! There are a couple of elements of this problem that make it one of the easiest graph problems out there.
Normally, we have to use some kind of structure to store a search state, telling us the next nodes in our graph to search. For BFS, this is a queue. For Dijkstra or A*, it is a heap. For DFS it is normally a stack. However, we are using the call stack to act as the stack for us!
When we make a recursive call to “visit” a location, we don’t need to keep track of which node we return to after we’re done. The function returns, and the prior node is already sitting there on the call stack.
The other simplifying factor is that we don’t need to do any backtracking or state restoration. Sometimes with a DFS, you need to “undo” some of the steps you took if you don’t find your goal. But this algorithm is just a space-filling algorithm. We are just trying to populate our “visited” set, and we never take nodes out of this set once we have visited them.
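To make the call-stack point concrete, here is the same flood fill written with an explicit stack (a plain Vec) instead of recursion. This is a sketch of our own, not code from the article, and the function name is hypothetical:

```rust
use std::collections::HashSet;

// The same flood fill, but a Vec plays the role the call stack
// plays in the recursive `visit`.
pub fn visit_iterative(grid: &Vec<Vec<char>>, start: (usize, usize)) -> HashSet<(usize, usize)> {
    let m = grid.len();
    let n = grid[0].len();
    let mut visited = HashSet::new();
    let mut stack = vec![start];
    while let Some((r, c)) = stack.pop() {
        // insert returns false if we already explored this cell
        if !visited.insert((r, c)) {
            continue;
        }
        if r > 0 && grid[r - 1][c] == '1' { stack.push((r - 1, c)); }
        if c > 0 && grid[r][c - 1] == '1' { stack.push((r, c - 1)); }
        if r + 1 < m && grid[r + 1][c] == '1' { stack.push((r + 1, c)); }
        if c + 1 < n && grid[r][c + 1] == '1' { stack.push((r, c + 1)); }
    }
    visited
}

fn main() {
    let grid: Vec<Vec<char>> = ["111000", "100101", "100111", "110000", "000110"]
        .iter()
        .map(|s| s.chars().collect())
        .collect();
    // The island in the top left corner of the example grid has 7 cells.
    assert_eq!(visit_iterative(&grid, (0, 0)).len(), 7);
}
```

Because this is pure space-filling with no backtracking, the order in which cells come off the stack doesn't matter, which is exactly why the recursive version gets away with no explicit data structure.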
Conclusion
We’ve got a couple more graph problems coming up next. If you want to learn more about applying graph algorithms in Haskell (including implementing them for yourself!) check out Solve.hs, where Module 3 will teach you about algorithms including DFS, BFS, A* and beyond!
Sets & Heaps in Haskell and Rust
In the last few articles, we’ve spent some time doing manual algorithms with binary trees. But most often, you won’t have to work with binary trees yourself, as they are built in to the operations of other data structures.
Today, we’ll solve a problem that relies on using the “set” data structure as well as the “heap” data structure. To learn more details about using these structures in Haskell, sign up for our Solve.hs course, where you’ll spend an entire module on data structures!
The Problem
Today’s problem is Find K Pairs with Smallest Sums. Our problem input consists of two sorted arrays of integers as well as a number k. Our job is to return the k pairs of numbers that have the smallest sum, where a “pair” consists of one number from the first array, and one number from the second array.
Here’s an example input and output:
inputNums1 = [1,10,101,102]
inputNums2 = [7,8,9,10,11]
k = 11
output =
[ (1,7), (1,8), (1,9), (1,10), (1,11)
, (10,7), (10,8), (10,9), (10,10)
, (10,11), (101,7)
]
Observe that we are returning the numbers themselves in pairs, rather than the sums, and not the indices of the numbers.
While the (implicit) indices of each pair must be unique, the numbers do not need to be unique. For example, if we have the lists [1,1] and [5,7], and k is 2, we should return [(1,5), (1,5)], where both of the 1’s from the first list are paired with the 5 from the second list.
The Algorithm
At the center of our algorithm is a min-heap. We want to store elements that contain pairs of numbers, ordered by their sum. We also want each element to contain the index of each number in its array.
We’ll start by making a pair of the first element from each array, and inserting that into our heap, because this pair must be the smallest. Then we’ll do a loop where we extract the minimum from the heap, and then try adding its “neighbors” into the heap. A “neighbor” of the index pair (i1, i2) comes from incrementing one of the two indices, so either (i1 + 1, i2) or (i1, i2 + 1). We continue until we have extracted k elements (or our heap is empty).
Each time we insert a pair of numbers into the heap, we’ll insert the pair of indices into a set. This will allow us to avoid double counting any pairs.
So using the first example in the section above, here are the first few steps of our process. Inside the heap is a 3-tuple with the sum, the indices (a 2-tuple), and the values (another 2-tuple).
Step 1:
Heap: [(8, (0,0), (1,7))]
Visited: [(0,0)]
Output: []
Step 2:
Heap: [(9, (0,1), (1,8)), (17, (1,0), (10,7))]
Visited: [(0,0), (0,1), (1,0)]
Output: [(1,7)]
Step 3:
Heap: [(10, (0,2), (1,9)), (17, (1,0), (10,7)), (18, (1,1), (10,8))]
Visited: [(0,0), (0,1), (0,2), (1,0), (1,1)]
Output: [(1,7), (1,8)]
And so we would continue this process until we gathered k outputs.
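Before the Haskell version, here is a sketch of the same min-heap walk in Rust (the article's other language), using BinaryHeap with Reverse to get min-heap behavior. The names here are our own, and we store only the index pair alongside the sum, reading values back out of the arrays:

```rust
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};

pub fn find_k_pairs(nums1: &[i32], nums2: &[i32], k: usize) -> Vec<(i32, i32)> {
    let mut result = Vec::new();
    if nums1.is_empty() || nums2.is_empty() || k == 0 {
        return result;
    }
    // Reverse turns Rust's max-heap into a min-heap ordered by sum.
    let mut heap = BinaryHeap::new();
    let mut visited = HashSet::new();
    heap.push(Reverse((nums1[0] + nums2[0], (0usize, 0usize))));
    visited.insert((0usize, 0usize));
    while let Some(Reverse((_, (i1, i2)))) = heap.pop() {
        result.push((nums1[i1], nums2[i2]));
        if result.len() == k {
            break;
        }
        // Try the two "neighbor" index pairs; the visited set stops
        // us from enqueueing any pair twice.
        if i1 + 1 < nums1.len() && visited.insert((i1 + 1, i2)) {
            heap.push(Reverse((nums1[i1 + 1] + nums2[i2], (i1 + 1, i2))));
        }
        if i2 + 1 < nums2.len() && visited.insert((i1, i2 + 1)) {
            heap.push(Reverse((nums1[i1] + nums2[i2 + 1], (i1, i2 + 1))));
        }
    }
    result
}

fn main() {
    // The example from above: k = 11.
    let pairs = find_k_pairs(&[1, 10, 101, 102], &[7, 8, 9, 10, 11], 11);
    assert_eq!(pairs[0], (1, 7));
    assert_eq!(pairs[10], (101, 7));
    println!("{:?}", pairs);
}
```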
Haskell Solution
The core structure of this problem isn’t hard. We define some initial terms, such as our data structures and outputs, and then we have a single loop. We’ll express this single loop as a recursive function. It will take the number of remaining values, the visited set, the heap, and the accumulated output, and ultimately return the output. Throughout this problem we’ll use the type alias I2 to refer to a tuple of two integers.
import qualified Data.Heap as H
import qualified Data.Set as S
type I2 = (Int, Int)
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
where
n1 = V.length arr1
n2 = V.length arr2
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f remaining visited heap acc = ...
Let’s fill in our loop. We can start with the edge cases. If k is 0, or if our heap is empty, we should return the results list (in reverse).
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
where
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f 0 _ _ acc = reverse acc
f remaining visited heap acc = case H.view heap of
Nothing -> reverse acc
Just ((_, (i1, i2), (v1, v2)), restHeap) -> ...
Our primary case now results from extracting the min element of the heap. We don’t actually need its sum. We just put that in the first position of the tuple so that it is the primary sorting value for the heap. Let’s define the next possible coordinate pairs from adding to i1 and i2, as well as the new sums we get from using those indices:
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
where
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f 0 _ _ acc = reverse acc
f remaining visited heap acc = case H.view heap of
Nothing -> reverse acc
Just ((_, (i1, i2), (v1, v2)), restHeap) ->
let c1 = (i1 + 1, i2)
c2 = (i1, i2 + 1)
inc1 = arr1 V.! (i1 + 1) + v2
inc2 = v1 + arr2 V.! (i2 + 1)
in ...
Now we need to try adding these values to the remaining heap. If the index is too large, or if we’ve already visited the coordinate, we don’t add the new value, returning the old heap. Otherwise we insert it into our heap. (Laziness keeps the out-of-bounds case safe: inc1 and inc2 are only evaluated if the corresponding bounds check passes.) For the second value, we just use heap1 (from trying to add the first value) as the baseline.
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
where
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f 0 _ _ acc = reverse acc
f remaining visited heap acc = case H.view heap of
Nothing -> reverse acc
Just ((_, (i1, i2), (v1, v2)), restHeap) ->
let c1 = (i1 + 1, i2)
c2 = (i1, i2 + 1)
inc1 = arr1 V.! (i1 + 1) + v2
inc2 = v1 + arr2 V.! (i2 + 1)
heap1 = if i1 + 1 < n1 && S.notMember c1 visited
then H.insert (inc1, c1, (arr1 V.! (i1 + 1), v2)) restHeap else restHeap
heap2 = if i2 + 1 < n2 && S.notMember c2 visited
then H.insert (inc2, c2, (v1, arr2 V.! (i2 + 1))) heap1 else heap1
in ...
Now we complete our recursive loop function by adding these new coordinates to the visited set and making the recursive call. We decrement the remaining count and prepend the popped values to our accumulated list.
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
where
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f 0 _ _ acc = reverse acc
f remaining visited heap acc = case H.view heap of
Nothing -> reverse acc
Just ((_, (i1, i2), (v1, v2)), restHeap) ->
let c1 = (i1 + 1, i2)
c2 = (i1, i2 + 1)
inc1 = arr1 V.! (i1 + 1) + v2
inc2 = v1 + arr2 V.! (i2 + 1)
heap1 = if i1 + 1 < n1 && S.notMember c1 visited
then H.insert (inc1, c1, (arr1 V.! (i1 + 1), v2)) restHeap else restHeap
heap2 = if i2 + 1 < n2 && S.notMember c2 visited
then H.insert (inc2, c2, (v1, arr2 V.! (i2 + 1))) heap1 else heap1
visited' = foldr S.insert visited ([c1, c2] :: [I2])
in f (remaining - 1) visited' heap2 ((v1,v2) : acc)
To complete the function, we define our initial heap and make the first call to our loop function:
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = f k (S.singleton (0,0)) initialHeap []
where
val1 = arr1 V.! 0
val2 = arr2 V.! 0
initialHeap = H.singleton (val1 + val2, (0,0), (val1, val2))
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f = ...
Now we’re done! Here’s our complete solution:
import qualified Data.Heap as H
import qualified Data.Set as S
import qualified Data.Vector as V

type I2 = (Int, Int)
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = f k (S.singleton (0,0)) initialHeap []
where
val1 = arr1 V.! 0
val2 = arr2 V.! 0
initialHeap = H.singleton (val1 + val2, (0,0), (val1, val2))
n1 = V.length arr1
n2 = V.length arr2
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f 0 _ _ acc = reverse acc
f remaining visited heap acc = case H.view heap of
Nothing -> reverse acc
Just ((_, (i1, i2), (v1, v2)), restHeap) ->
let c1 = (i1 + 1, i2)
c2 = (i1, i2 + 1)
inc1 = arr1 V.! (i1 + 1) + v2
inc2 = v1 + arr2 V.! (i2 + 1)
heap1 = if i1 + 1 < n1 && S.notMember c1 visited
then H.insert (inc1, c1, (arr1 V.! (i1 + 1), v2)) restHeap else restHeap
heap2 = if i2 + 1 < n2 && S.notMember c2 visited
then H.insert (inc2, c2, (v1, arr2 V.! (i2 + 1))) heap1 else heap1
visited' = foldr S.insert visited ([c1, c2] :: [I2])
in f (remaining - 1) visited' heap2 ((v1,v2) : acc)
Rust Solution
Now, on to our Rust solution. We’ll start by defining our terms. These follow the pattern laid out in our algorithm and the Haskell solution:
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
let mut heap: BinaryHeap<Reverse<(i32, (usize,usize), (i32,i32))>> = BinaryHeap::new();
let val1 = nums1[0];
let val2 = nums2[0];
heap.push(Reverse((val1 + val2, (0,0), (val1, val2))));
let mut visited: HashSet<(usize, usize)> = HashSet::new();
visited.insert((0,0));
let mut results = Vec::new();
let mut remaining = k;
let n1 = nums1.len();
let n2 = nums2.len();
...
return results;
}
The most interesting of these is the heap. We parameterize the type of the BinaryHeap using the same kind of tuple we had in Haskell. But in order to make it a “Min” heap, we have to wrap our values in the Reverse type.
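To see what Reverse buys us, here’s a tiny standalone sketch (not part of the solution): BinaryHeap pops the largest element by default, and wrapping values in std::cmp::Reverse inverts the comparison so that pop yields the smallest.

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn main() {
    // Default behavior: a max-heap, so the largest value pops first.
    let mut max_heap = BinaryHeap::from(vec![3, 1, 2]);
    assert_eq!(max_heap.pop(), Some(3));

    // Wrapping every element in Reverse flips the ordering,
    // turning pop into "remove the smallest".
    let mut min_heap: BinaryHeap<Reverse<i32>> = BinaryHeap::new();
    for x in [3, 1, 2] {
        min_heap.push(Reverse(x));
    }
    assert_eq!(min_heap.pop(), Some(Reverse(1)));
}
```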
Now let’s define the outline of our loop. We keep going until remaining is 0. We will also break if we can’t pop a value from the heap.
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
...
while remaining > 0 {
if let Some(Reverse((_sum, (i1, i2), (v1, v2)))) = heap.pop() {
...
} else {
break;
}
remaining -= 1;
}
return results;
}
Now we’ll define our new coordinates, and add tests for whether or not these can be added to the heap:
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
...
while remaining > 0 {
if let Some(Reverse((_sum, (i1, i2), (v1, v2)))) = heap.pop() {
let c1 = (i1 + 1, i2);
let c2 = (i1, i2 + 1);
if i1 + 1 < n1 && !visited.contains(&c1) {
let inc1 = nums1[i1 + 1] + v2;
...
}
if i2 + 1 < n2 && !visited.contains(&c2) {
let inc2 = v1 + nums2[i2 + 1];
...
}
} else {
break;
}
remaining -= 1;
}
return results;
}
In each case, we add the new coordinate to visited, and push the new element on to the heap. We also push the values onto our results array.
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
...
while remaining > 0 {
if let Some(Reverse((_sum, (i1, i2), (v1, v2)))) = heap.pop() {
let c1 = (i1 + 1, i2);
let c2 = (i1, i2 + 1);
if i1 + 1 < n1 && !visited.contains(&c1) {
let inc1 = nums1[i1 + 1] + v2;
visited.insert(c1);
heap.push(Reverse((inc1, c1, (nums1[i1 + 1], v2))));
}
if i2 + 1 < n2 && !visited.contains(&c2) {
let inc2 = v1 + nums2[i2 + 1];
visited.insert(c2);
heap.push(Reverse((inc2, c2, (v1, nums2[i2 + 1]))));
}
results.push(vec![v1,v2]);
} else {
break;
}
remaining -= 1;
}
return results;
}
And that’s all we need!
use std::collections::{BinaryHeap, HashSet};
use std::cmp::Reverse;
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
let mut heap: BinaryHeap<Reverse<(i32, (usize,usize), (i32,i32))>> = BinaryHeap::new();
let val1 = nums1[0];
let val2 = nums2[0];
heap.push(Reverse((val1 + val2, (0,0), (val1, val2))));
let mut visited: HashSet<(usize, usize)> = HashSet::new();
visited.insert((0,0));
let mut results = Vec::new();
let mut remaining = k;
let n1 = nums1.len();
let n2 = nums2.len();
while remaining > 0 {
if let Some(Reverse((_sum, (i1, i2), (v1, v2)))) = heap.pop() {
let c1 = (i1 + 1, i2);
let c2 = (i1, i2 + 1);
if i1 + 1 < n1 && !visited.contains(&c1) {
let inc1 = nums1[i1 + 1] + v2;
visited.insert(c1);
heap.push(Reverse((inc1, c1, (nums1[i1 + 1], v2))));
}
if i2 + 1 < n2 && !visited.contains(&c2) {
let inc2 = v1 + nums2[i2 + 1];
visited.insert(c2);
heap.push(Reverse((inc2, c2, (v1, nums2[i2 + 1]))));
}
results.push(vec![v1,v2]);
} else {
break;
}
remaining -= 1;
}
return results;
}
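As a sanity check, here’s the same solution bundled with a small driver on the classic example input (nums1 = [1,7,11], nums2 = [2,4,6], k = 3); the function is repeated so the snippet stands alone:

```rust
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};

pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
    // Min-heap of (sum, coordinate pair, value pair).
    let mut heap: BinaryHeap<Reverse<(i32, (usize, usize), (i32, i32))>> = BinaryHeap::new();
    let val1 = nums1[0];
    let val2 = nums2[0];
    heap.push(Reverse((val1 + val2, (0, 0), (val1, val2))));
    let mut visited: HashSet<(usize, usize)> = HashSet::new();
    visited.insert((0, 0));
    let mut results = Vec::new();
    let mut remaining = k;
    let n1 = nums1.len();
    let n2 = nums2.len();
    while remaining > 0 {
        if let Some(Reverse((_sum, (i1, i2), (v1, v2)))) = heap.pop() {
            let c1 = (i1 + 1, i2);
            let c2 = (i1, i2 + 1);
            if i1 + 1 < n1 && !visited.contains(&c1) {
                visited.insert(c1);
                heap.push(Reverse((nums1[i1 + 1] + v2, c1, (nums1[i1 + 1], v2))));
            }
            if i2 + 1 < n2 && !visited.contains(&c2) {
                visited.insert(c2);
                heap.push(Reverse((v1 + nums2[i2 + 1], c2, (v1, nums2[i2 + 1]))));
            }
            results.push(vec![v1, v2]);
        } else {
            break;
        }
        remaining -= 1;
    }
    results
}

fn main() {
    // The three smallest sums are 1+2, 1+4, and 1+6.
    let pairs = k_smallest_pairs(vec![1, 7, 11], vec![2, 4, 6], 3);
    assert_eq!(pairs, vec![vec![1, 2], vec![1, 4], vec![1, 6]]);
}
```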
Conclusion
Next time, we’ll start exploring some graph problems, which also rely on the kinds of data structures we used here!
I didn’t explain too much in this article about the details of using these various data structures. If you want an in-depth exploration of how data structures work in Haskell, including the “common API” that helps you use almost all of them, you should sign up for our Solve.hs course! Module 2 is completely dedicated to teaching you about data structures, and you’ll get a lot of practice working with these structures in sample problems.
Binary Tree BFS: Zigzag Order
In our last article, we explored how to perform an in-order traversal of a binary search tree. Today we’ll do one final binary tree problem to solidify our understanding of some common tree patterns, as well as the tricky syntax for dealing with a binary tree in Rust.
If you want some interesting challenge problems using Haskell data structures, you should take our Solve.hs course. In particular, you’ll learn how to write a self-balancing binary tree to use for an ordered set!
The Problem
Today we will solve Zigzag Level Order Traversal. For any binary tree, we can think about it in terms of “levels” based on the number of steps from the root. So given this tree:
        45
       /  \
     32    50
    /  \     \
   5    40   100
       /  \
     37   43
We can visually see that there are 4 levels. So a normal level order traversal would return a list of 4 lists, where each list is a single level, ordered from left to right, visually speaking:
[45]
[32, 50]
[5, 40, 100]
[37, 43]
However, with a zigzag level order traversal, every other level is reversed. So we should get the following result for the input tree:
[45]
[50, 32]
[5, 40, 100]
[43, 37]
So we can imagine that we do the first level from left to right and then zigzag back to get the second level from right to left. Then we do left to right again for the third level, and so on.
The Algorithm
For our in-order traversal, we used a kind of depth-first search (DFS), and this approach is more common for tree-based problems. However, for a level-order problem, we want more of a breadth-first search (BFS). In a BFS, we explore states in order of their distance to the root. Since all nodes in a level have the same distance to the root, this makes sense.
Our general idea is that we’ll store a list of all the nodes from the prior level. Initially, this will just contain the root node. We’ll loop through this list, and create a new list of the values from the nodes in this list. This gets appended to our final result list.
While we’re doing this loop, we’ll also compose the list for the next level. The only trick is knowing whether to add each node’s left or right child to the next-level list first. This order flips each iteration, so we’ll need a boolean flag to track it.
Once we encounter a level that produces no numbers (i.e. it only contains Nil nodes), we can stop iterating and return our list of lists.
Rust Solution
Now that we’re a bit more familiar with manipulating Rc RefCells, we’ll start with the Rust solution, framing it according to the two-loop structure in our algorithm. We’ll define stack1, which is the iteration stack, and stack2, where we accumulate the new nodes for the next layer. We also define our final result vector, a list of lists.
pub fn zigzag_level_order(root: Option<Rc<RefCell<TreeNode>>>) -> Vec<Vec<i32>> {
let mut result: Vec<Vec<i32>> = Vec::new();
let mut stack1: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
stack1.push(root.clone());
let mut stack2: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
let mut leftToRight = true;
...
return result;
}
Our initial loop will continue until stack1 no longer contains any elements. So our basic condition is while (!stack1.is_empty()). However, there’s another important element here.
After we accumulate the new nodes in stack2, we want to flip the meanings of our two stacks. We want our accumulated nodes referred to by stack1, and stack2 to be an empty list to accumulate. We accomplish this in Rust by clearing stack1 at the end of our loop, and then using std::mem::swap to flip their meanings:
pub fn zigzag_level_order(root: Option<Rc<RefCell<TreeNode>>>) -> Vec<Vec<i32>> {
let mut result: Vec<Vec<i32>> = Vec::new();
let mut stack1: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
stack1.push(root.clone());
let mut stack2: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
let mut leftToRight = true;
while (!stack1.is_empty()) {
let mut thisLayer = Vec::new(); // Values from this level
...
leftToRight = !leftToRight;
stack1.clear();
mem::swap(&mut stack1, &mut stack2);
}
return result;
}
In C++ we could accomplish something like this using std::move, but only because we want stack2 to return to an empty state (a moved-from vector is typically left empty):
stack1 = std::move(stack2);
Also, observe that we flip our boolean flag at the end of the iteration.
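Here’s a standalone illustration of the clear-and-swap pattern; note that std::mem::take would also work, replacing a vector with a fresh empty one in a single call (the variable names are just for the demo):

```rust
use std::mem;

fn main() {
    let mut stack1 = vec![1, 2, 3]; // finished layer
    let mut stack2 = vec![4, 5];    // accumulated next layer

    // The pattern from the solution: empty stack1, then swap.
    stack1.clear();
    mem::swap(&mut stack1, &mut stack2);
    assert_eq!(stack1, vec![4, 5]);
    assert!(stack2.is_empty());

    // Equivalent one-liner: take moves the contents out,
    // leaving an empty vector behind.
    let mut next = vec![6, 7];
    stack1 = mem::take(&mut next);
    assert_eq!(stack1, vec![6, 7]);
    assert!(next.is_empty());
}
```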
Now let’s get to work on the inner loop. This will actually go through stack1, add values to thisLayer, and accumulate the next layer of nodes for stack2. An interesting finding is that whether we’re going left to right or vice versa, we want to loop through stack1 in reverse. This means we’re treating it like a true stack instead of a vector, first accessing the last node to be added.
A left-to-right pass will add lefts and then rights. This means the right-most node in the next layer is on “top” of the stack, at the end of the vector. A right-to-left pass will first add the right child for a node before its left. This means the left-most node of the next layer is at the end of the vector.
Let’s frame up this loop, and also add the results of this layer to our final result vector.
pub fn zigzag_level_order(root: Option<Rc<RefCell<TreeNode>>>) -> Vec<Vec<i32>> {
...
while (!stack1.is_empty()) {
let mut thisLayer = Vec::new();
for node in stack1.iter().rev() {
...
}
if (!thisLayer.is_empty()) {
result.push(thisLayer);
}
leftToRight = !leftToRight;
stack1.clear();
mem::swap(&mut stack1, &mut stack2);
}
return result;
}
Note that we do not add the values array if it is empty. We allow ourselves to accumulate None nodes in our stack. The final layer we encounter will actually consist of all None nodes, and we don’t want this layer to add an empty list.
Now all we need to do is populate the inner loop. We only take action if the node from stack1 is Some instead of None. Then we follow a few simple steps:
- Borrow the TreeNode from this RefCell.
- Push its value onto thisLayer.
- Add its children (using clone) to stack2, in the right order.
Here’s the code:
pub fn zigzag_level_order(root: Option<Rc<RefCell<TreeNode>>>) -> Vec<Vec<i32>> {
...
while (!stack1.is_empty()) {
let mut thisLayer = Vec::new();
for node in stack1.iter().rev() {
if let Some(current) = node {
let currentTreeNode = current.borrow();
thisLayer.push(currentTreeNode.val);
if leftToRight {
stack2.push(currentTreeNode.left.clone());
stack2.push(currentTreeNode.right.clone());
} else {
stack2.push(currentTreeNode.right.clone());
stack2.push(currentTreeNode.left.clone());
}
}
}
...
}
return result;
}
And now we’re done! Here’s the full solution:
use std::rc::Rc;
use std::cell::RefCell;
use std::mem;
pub fn zigzag_level_order(root: Option<Rc<RefCell<TreeNode>>>) -> Vec<Vec<i32>> {
let mut result: Vec<Vec<i32>> = Vec::new();
let mut stack1: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
stack1.push(root.clone());
let mut stack2: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
let mut leftToRight = true;
while (!stack1.is_empty()) {
let mut thisLayer = Vec::new();
for node in stack1.iter().rev() {
if let Some(current) = node {
let currentTreeNode = current.borrow();
thisLayer.push(currentTreeNode.val);
if leftToRight {
stack2.push(currentTreeNode.left.clone());
stack2.push(currentTreeNode.right.clone());
} else {
stack2.push(currentTreeNode.right.clone());
stack2.push(currentTreeNode.left.clone());
}
}
}
if (!thisLayer.is_empty()) {
result.push(thisLayer);
}
leftToRight = !leftToRight;
stack1.clear();
mem::swap(&mut stack1, &mut stack2);
}
return result;
}
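To see the whole thing run, here’s a hedged end-to-end sketch: it bundles LeetCode’s standard TreeNode definition, a node builder helper of our own, the function above (with snake_case names), and an assertion against the example tree:

```rust
use std::cell::RefCell;
use std::mem;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

type Link = Option<Rc<RefCell<TreeNode>>>;

// Builder helper (ours, not part of the problem statement).
fn node(val: i32, left: Link, right: Link) -> Link {
    Some(Rc::new(RefCell::new(TreeNode { val, left, right })))
}

pub fn zigzag_level_order(root: Link) -> Vec<Vec<i32>> {
    let mut result: Vec<Vec<i32>> = Vec::new();
    let mut stack1: Vec<Link> = vec![root];
    let mut stack2: Vec<Link> = Vec::new();
    let mut left_to_right = true;
    while !stack1.is_empty() {
        let mut this_layer = Vec::new();
        // Iterate in reverse: the vector is treated as a true stack.
        for entry in stack1.iter().rev() {
            if let Some(current) = entry {
                let tree_node = current.borrow();
                this_layer.push(tree_node.val);
                if left_to_right {
                    stack2.push(tree_node.left.clone());
                    stack2.push(tree_node.right.clone());
                } else {
                    stack2.push(tree_node.right.clone());
                    stack2.push(tree_node.left.clone());
                }
            }
        }
        if !this_layer.is_empty() {
            result.push(this_layer);
        }
        left_to_right = !left_to_right;
        stack1.clear();
        mem::swap(&mut stack1, &mut stack2);
    }
    result
}

fn main() {
    let tree = node(
        45,
        node(
            32,
            node(5, None, None),
            node(40, node(37, None, None), node(43, None, None)),
        ),
        node(50, None, node(100, None, None)),
    );
    assert_eq!(
        zigzag_level_order(tree),
        vec![vec![45], vec![50, 32], vec![5, 40, 100], vec![43, 37]]
    );
}
```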
Haskell Solution
While our Rust solution was better described from the outside in, it’s easy to build the Haskell solution from the inside out. We have two loops, and we can start by defining the inner loop (we’ll call it the stack loop).
The goal of this loop is to take stack1 and turn it into stack2 (the next layer) and the numbers for this layer, while also tracking the direction of iteration. Both outputs are accumulated as lists, so we have inputs for them as well:
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = ...
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop isLeftToRight stack1 stack2 nums = ...
When stack1 is empty, we return our result from this loop. Because of list accumulation order, we reverse nums when giving the result. However, we don’t reverse stack2, because we want to iterate starting from the “top”. This seems like the opposite of what we did in Rust, because Rust uses a vector for its stack type, instead of a singly linked list!
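The Rust side of this observation is easy to check in isolation: reading a Vec in reverse visits the most recently pushed element first, which matches reading a Haskell cons-list from the front. A quick sketch:

```rust
fn main() {
    // Pushing appends to the end of the Vec, so reading it in reverse
    // visits the most recently pushed element first -- the same order
    // you get reading a Haskell list built with (:) from the front.
    let mut stack = Vec::new();
    for x in ["a", "b", "c"] {
        stack.push(x);
    }
    let lifo: Vec<&str> = stack.iter().rev().cloned().collect();
    assert_eq!(lifo, vec!["c", "b", "a"]);
}
```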
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = ...
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop _ [] stack2 nums = (stack2, reverse nums)
stackLoop isLeftToRight (Nil : rest) stack2 nums = stackLoop isLeftToRight rest stack2 nums
stackLoop isLeftToRight (Node x left right : rest) stack2 nums = ...
Observe also a second edge case: for Nil nodes in stack1, we just recurse on the rest of the list. Now for the main case, we just define the new stack2, which adds the child nodes in the correct order. Then we recurse while also adding x to nums.
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = ...
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop _ [] stack2 nums = (stack2, reverse nums)
stackLoop isLeftToRight (Nil : rest) stack2 nums = stackLoop isLeftToRight rest stack2 nums
stackLoop isLeftToRight (Node x left right : rest) stack2 nums =
let stack2' = if isLeftToRight then right : left : stack2 else left : right : stack2
in stackLoop isLeftToRight rest stack2' (x : nums)
...
Now we’ll define the outer loop, which we’ll call the layerLoop. This takes the direction flag and stack1, plus the accumulator list for the results. It also has a simple base case to reverse the results list once stack1 is empty.
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = layerLoop True [root] []
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop = ...
layerLoop :: Bool -> [TreeNode] -> [[Int]] -> [[Int]]
layerLoop _ [] allNums = reverse allNums
layerLoop isLeftToRight stack1 allNums = ...
Now in the recursive case, we call the stackLoop to get our new numbers and the stack for the next layer (which we now think of as our new stack1). We then recurse, flipping the boolean flag and adding these new numbers to our results, but only if the list is not empty.
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = layerLoop True [root] []
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop = ...
layerLoop :: Bool -> [TreeNode] -> [[Int]] -> [[Int]]
layerLoop _ [] allNums = reverse allNums
layerLoop isLeftToRight stack1 allNums =
let (stack1', newNums) = stackLoop isLeftToRight stack1 [] []
in layerLoop (not isLeftToRight) stack1' (if null newNums then allNums else newNums : allNums)
The last step, as you’ve seen, is calling layerLoop from the start with [root]. We’re done! Here’s our final implementation:
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = layerLoop True [root] []
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop _ [] stack2 nums = (stack2, reverse nums)
stackLoop isLeftToRight (Nil : rest) stack2 nums = stackLoop isLeftToRight rest stack2 nums
stackLoop isLeftToRight (Node x left right : rest) stack2 nums =
let stack2' = if isLeftToRight then right : left : stack2 else left : right : stack2
in stackLoop isLeftToRight rest stack2' (x : nums)
layerLoop :: Bool -> [TreeNode] -> [[Int]] -> [[Int]]
layerLoop _ [] allNums = reverse allNums
layerLoop isLeftToRight stack1 allNums =
let (stack1', newNums) = stackLoop isLeftToRight stack1 [] []
in layerLoop (not isLeftToRight) stack1' (if null newNums then allNums else newNums : allNums)
Conclusion
That’s all we’ll do for binary trees right now. In the coming articles we’ll continue to explore more data structures as well as some common algorithms. If you want to learn more about data structures and algorithms in Haskell, check out our course Solve.hs. Modules 2 & 3 are filled with this sort of content, including lots of practice problems.
In-Order Traversal in Haskell and Rust
Last time around, we started exploring binary trees. We began with a simple problem (inverting a tree), but encountered some of the difficulties implementing a recursive data structure in Rust.
Today we’ll do a slightly harder problem (LeetCode rates it as “Medium” instead of “Easy”). This problem also specifically works with a binary search tree instead of a simple binary tree. With a search tree, we have the property that the “values” on each node are orderable, and all the values to the “left” of any given node are no greater than that node’s value, while the values to the “right” are no smaller.
Binary search trees are the heart of any ordered Set type. In our problem solving course Solve.hs, you’ll get the chance to build a self-balancing binary search tree from scratch, which involves some really cool algorithmic tricks!
The Problem
Though it’s harder than our previous problem, today’s problem is still straightforward. We are taking an ordered binary search tree and finding the k-th smallest element in that tree, where k is the second input to our function.
So suppose our tree looks like this:
        45
       /  \
     32    50
    /  \     \
   5    40   100
       /  \
     37   43
Our input k is 1-indexed. So if we get 1 as our input, we should return 5, the smallest element in the tree. If we receive 4, we should return 40, the 4th smallest element after 5, 32 and 37. If we get 8, we’ll return 100, the largest element in the tree.
The Algorithm
Binary search trees are designed to give logarithmic time access, insertions, and deletions for elements. If our tree was annotated, so that each node stored the number of children it had, we’d be able to solve this problem in logarithmic time as well.
However, we want to assume a minimal tree design, where each node only holds its own value and pointers to its two children. With these constraints, our algorithm has to be linear in terms of the input k.
We’re going to solve this with an in-order traversal. We’re just going to traverse the elements of the BST in order from smallest value to largest, and count until we’ve encountered k elements. Then we’ll return the value at our “current” node.
An in-order traversal is conceptually simple. For a given node, we “visit” the left child, then visit the node itself, and then visit the right child. But the actual mechanics of doing this traversal can be a little tricky to think of on the spot if you haven’t practiced it before.
The main idea is that we’ll use a stack of nodes to track where we are in the tree. The stack traces a path from our “current” node up through its parents to the root of the tree. Our algorithm looks like this:
First, we create a stack from the root, following all left child nodes until we reach a node without a left child. This node has the smallest element of the tree.
Now, we begin processing, always considering the top node in the stack, while tracking the number of elements remaining until we hit k. If that number is down to 1, then the value at the top node in the stack is our result.
If not, we’ll decrement the number of elements remaining, and then check the right child of this node. If the right child is Nil, we just pop this node from the stack and process its parent. If the right child does exist, we’ll add all of its left children to our stack as well.
Here’s what the stack looks like with our example tree, if we’re looking for k=7 (which should give us 50).
[5, 32, 45] -- k = 1, Initial left children of root
[32, 45] -- k = 2, popped 5, no right child
[37, 40, 45] -- k = 3, popped 32, right child was 40, which added left child 37
[40, 45] -- k = 4, popped 37, no right child
[43, 45] -- k = 5, popped 40, and 43 is right child
[45] -- k = 6, popped 43
[50] -- k = 7, popped 45 and added 50, the right child (no left children)
Since 50 is on top of the stack with k = 7, we can return 50.
Haskell Solution
Let’s code this up! We’ll start with Haskell, since Rust is, once again, somewhat tricky due to TreeNode handling. To start, let’s remind ourselves of the recursive TreeNode type:
data TreeNode = Nil | Node Int TreeNode TreeNode
deriving (Show, Eq)
Now when writing up the algorithm, we first want to define a helper function addLeftNodesToStack. We’ll use this at the beginning of the algorithm, and then again each time we encounter a right child.
This helper will take a TreeNode and the existing stack, and return the modified stack.
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = ...
where
addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
addLeftNodesToStack = ...
As far as recursive helpers go, this is a simple one! If our input node is Nil, we return the original stack. We want to maintain the invariant that we never include Nil values in our stack! But if we have a value node, we just add it to the stack and recurse on its left child.
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = ...
where
addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
addLeftNodesToStack Nil acc = acc
addLeftNodesToStack root@(Node _ left _) acc = addLeftNodesToStack left (root : acc)
Now it’s time to implement our algorithm for finding the k-th element. This will be a recursive function that takes the number of elements remaining, as well as the current stack. We’ll call this initially with k and the stack we get from adding the left nodes of the root:
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
where
addLeftNodesToStack = ...
findK :: Int -> [TreeNode] -> Int
findK = ...
This function has a couple of error cases. We expect a non-empty stack (our input k is constrained by the size of the tree), and we expect the top to be non-Nil. After that, we have our base case where k = 1, and we return the value at this node.
Finally, we get our recursive case. We decrement the remaining count, and add the left nodes of the right child of this node to the stack.
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
where
addLeftNodesToStack = ...
findK :: Int -> [TreeNode] -> Int
findK k [] = error $ "Found empty list expecting k: " ++ show k
findK _ (Nil : _) = error "Added Nil to stack!"
findK 1 (Node x _ _ : _) = x
findK k (Node _ _ right : rest) = findK (k - 1) (addLeftNodesToStack right rest)
This completes our solution!
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
where
addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
addLeftNodesToStack Nil acc = acc
addLeftNodesToStack root@(Node _ left _) acc = addLeftNodesToStack left (root : acc)
findK :: Int -> [TreeNode] -> Int
findK k [] = error $ "Found empty list expecting k: " ++ show k
findK _ (Nil : _) = error "Added Nil to stack!"
findK 1 (Node x _ _ : _) = x
findK k (Node _ _ right : rest) = findK (k - 1) (addLeftNodesToStack right rest)
Rust Solution
In our Rust solution, we’re once again working with this TreeNode type, including the 3 wrapper layers:
#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
pub val: i32,
pub left: Option<Rc<RefCell<TreeNode>>>,
pub right: Option<Rc<RefCell<TreeNode>>>,
}
Our first step will be to implement the helper function to add the “left” nodes. This function will take a “root” node as well as a mutable reference to the stack so we can add nodes to it.
fn add_left_nodes_to_stack(
node: Option<Rc<RefCell<TreeNode>>>,
stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
...
}
You’ll notice that stack does not actually use the Option wrapper, only Rc and RefCell. Remember that in our Haskell solution we wanted to enforce that we never add null (Nil) nodes to the stack. This Rust solution enforces that constraint at compile time.
To implement this function, we’ll use the same trick we did when inverting trees to pattern match on node and detect if it is Some or None. If it is None, we don’t have to do anything.
fn add_left_nodes_to_stack(
node: Option<Rc<RefCell<TreeNode>>>,
stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
if let Some(current) = node {
...
}
}
Since current is now unwrapped from Option, we can push it to the stack. As in our previous problem though, we have to clone it first! We need a clone of the reference (as wrapped by Rc), because the stack will now have to own this reference.
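Cloning here is cheap: it copies the Rc pointer and bumps a reference count rather than deep-copying the subtree. A standalone sketch (not problem code):

```rust
use std::rc::Rc;

fn main() {
    let original = Rc::new(String::from("node"));
    assert_eq!(Rc::strong_count(&original), 1);

    // clone produces a second owner of the SAME allocation.
    let shared = original.clone();
    assert_eq!(Rc::strong_count(&original), 2);
    assert!(Rc::ptr_eq(&original, &shared));
}
```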
fn add_left_nodes_to_stack(
node: Option<Rc<RefCell<TreeNode>>>,
stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
if let Some(current) = node {
stack.push(current.clone());
...
}
}
Now we’ll recurse on the left child of current. In order to unwrap the TreeNode from Rc/RefCell, we have to use borrow. Then we can grab the left value. But again, we have to clone it before we make the recursive call. Here’s the final implementation of this helper:
fn add_left_nodes_to_stack(
node: Option<Rc<RefCell<TreeNode>>>,
stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
if let Some(current) = node {
stack.push(current.clone());
add_left_nodes_to_stack(current.borrow().left.clone(), stack);
}
}
We could have implemented the helper with a while loop instead of recursion. This would actually have used less memory in Rust! We would have to make some changes though, like making a new mut reference from the root.
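For reference, the loop-based variant might look like this sketch (the _iter name and the mini-tree in main are our own):

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

// Iterative variant: walk the left spine with a mutable local
// instead of recursing.
fn add_left_nodes_to_stack_iter(
    node: Option<Rc<RefCell<TreeNode>>>,
    stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    let mut current = node;
    while let Some(n) = current {
        stack.push(n.clone());
        current = n.borrow().left.clone();
    }
}

fn main() {
    // Left spine 45 -> 32 -> 5 from the example tree.
    let leaf5 = Rc::new(RefCell::new(TreeNode { val: 5, left: None, right: None }));
    let n32 = Rc::new(RefCell::new(TreeNode { val: 32, left: Some(leaf5), right: None }));
    let root = Rc::new(RefCell::new(TreeNode { val: 45, left: Some(n32), right: None }));
    let mut stack = Vec::new();
    add_left_nodes_to_stack_iter(Some(root), &mut stack);
    let vals: Vec<i32> = stack.iter().map(|n| n.borrow().val).collect();
    assert_eq!(vals, vec![45, 32, 5]);
}
```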
Now we can move on to the core function. We’ll start this by defining key terms like our stack and the number of “remaining” values (initially k). We’ll also call our helper to get the initial stack.
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
let mut stack = Vec::new();
let mut remaining = k;
add_left_nodes_to_stack(root, &mut stack);
...
}
Now we want to pop the top element from the stack, and pattern match it as requiring Some. If there are no more values, we’ll actually panic, because the problem constraints should mean that our stack is never empty. Unlike our helper, we actually will use a while loop here instead of more recursion:
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
let mut stack = Vec::new();
let mut remaining = k;
add_left_nodes_to_stack(root, &mut stack);
while let Some(current) = stack.pop() {
...
}
panic!("k is larger than number of nodes");
}
Now the inside of the loop is simple, following what we’ve done in Haskell. If our remainder is 1, then we have found the correct node. We borrow the node from the RefCell and return its value. Otherwise we decrement the count and use our helper on the “right” child of the node we just popped. As usual, the RefCell wrapper means we need to borrow to get the right value from the TreeNode, and then we clone this child as we pass it to the helper.
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
let mut stack = Vec::new();
let mut remaining = k;
add_left_nodes_to_stack(root, &mut stack);
while let Some(current) = stack.pop() {
if remaining == 1 {
return current.borrow().val;
}
remaining -= 1;
add_left_nodes_to_stack(current.borrow().right.clone(), &mut stack);
}
panic!("k is larger than number of nodes");
}
And that’s it! Here’s the complete Rust solution:
use std::rc::Rc;
use std::cell::RefCell;
fn add_left_nodes_to_stack(
node: Option<Rc<RefCell<TreeNode>>>,
stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
if let Some(current) = node {
stack.push(current.clone());
add_left_nodes_to_stack(current.borrow().left.clone(), stack);
}
}
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
let mut stack = Vec::new();
let mut remaining = k;
add_left_nodes_to_stack(root, &mut stack);
while let Some(current) = stack.pop() {
if remaining == 1 {
return current.borrow().val;
}
remaining -= 1;
add_left_nodes_to_stack(current.borrow().right.clone(), &mut stack);
}
panic!("k is larger than number of nodes");
}
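Here’s the solution exercised against the example tree, bundled with LeetCode’s TreeNode type and a node builder helper of our own so it runs standalone:

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

type Link = Option<Rc<RefCell<TreeNode>>>;

// Builder helper (ours, not part of the problem statement).
fn node(val: i32, left: Link, right: Link) -> Link {
    Some(Rc::new(RefCell::new(TreeNode { val, left, right })))
}

fn add_left_nodes_to_stack(node: Link, stack: &mut Vec<Rc<RefCell<TreeNode>>>) {
    if let Some(current) = node {
        stack.push(current.clone());
        add_left_nodes_to_stack(current.borrow().left.clone(), stack);
    }
}

pub fn kth_smallest(root: Link, k: i32) -> i32 {
    let mut stack = Vec::new();
    let mut remaining = k;
    add_left_nodes_to_stack(root, &mut stack);
    while let Some(current) = stack.pop() {
        if remaining == 1 {
            return current.borrow().val;
        }
        remaining -= 1;
        add_left_nodes_to_stack(current.borrow().right.clone(), &mut stack);
    }
    panic!("k is larger than number of nodes");
}

fn main() {
    // Rebuild the example tree for each query, since kth_smallest
    // takes ownership of the root.
    let build = || {
        node(
            45,
            node(
                32,
                node(5, None, None),
                node(40, node(37, None, None), node(43, None, None)),
            ),
            node(50, None, node(100, None, None)),
        )
    };
    assert_eq!(kth_smallest(build(), 1), 5);
    assert_eq!(kth_smallest(build(), 4), 40);
    assert_eq!(kth_smallest(build(), 7), 50);
    assert_eq!(kth_smallest(build(), 8), 100);
}
```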
Conclusion
In-order traversal is a great pattern to commit to memory, as many different tree problems will require you to apply it. Hopefully the details with Rust’s RefCells are getting more familiar. Next week we’ll do one more problem with binary trees.
If you want to do some deep work with Haskell and binary trees, take a look at our Solve.hs course, where you’ll learn about many different data structures in Haskell, and get the chance to write a balanced binary search tree from scratch!
An Easy Problem Made Hard: Rust & Binary Trees
In last week’s article, we completed our look at Matrix-based problems. Today, we’re going to start considering another data structure: binary trees.
Binary trees are an extremely important structure in programming. Most notably, they are the underlying structure for ordered sets, which allow logarithmic-time lookups and insertions. A “tree” is represented by “nodes”, where a node can be “null”, or else hold a value. If it holds a value, it then has a “left” child and a “right” child.
If you take our Solve.hs course, you’ll actually learn to implement an auto-balancing ordered tree set from scratch!
But for these next few articles, we’re going to explore some simple problems that involve binary trees. Today we’ll start with a problem that is very simple (rated as “Easy” by LeetCode), but still helps us grasp the core problem solving techniques behind binary trees. We’ll also encounter some interesting curveballs that Rust can throw at us when it comes to building more complex data structures.
The Problem
Our problem today is Invert Binary Tree. Given a binary tree, we want to return a new tree that is the mirror image of the input tree. For example, if we get this tree as an input:
        45
       /  \
     32    50
    /  \     \
   5    40    100
       /  \
     37    43
We should output a tree that looks like this:
        45
       /  \
     50    32
    /     /  \
  100    40    5
        /  \
      43    37
We see that 45 remains the root element, but instead of having 32 on the left and 50 on the right, these two elements are reversed on the next level. Then on the 3rd level, 40 and 5 remain children of 32, but they are also reversed from their prior orientations! This pattern continues all the way down on both sides.
The Algorithm
Binary trees (and in fact, all tree structures) lend themselves very well to recursive algorithms. We’ll use a very simple recursive algorithm here.
If the input node to our function is a “null” node, we will simply return “null” as the output. If the node has a value and children, then we’ll keep that value as the value for our output. However, we will recursively invert both of the child nodes.
Then we’ll take the inverted “left” child node and install it as the “right” child of our result. The inverted “right” child of the original input becomes the “left” child of the resulting node.
Haskell Solution
Haskell is a natural fit for this problem, since it relies so heavily on recursion. We start by defining a recursive TreeNode type. The canonical way to do this is with a Nil constructor as well as a recursive “value” constructor that actually holds the node’s value and refers to the left and right child. For this problem, we’ll just assume our tree holds Int values, so we won’t parameterize it.
data TreeNode = Nil | Node Int TreeNode TreeNode
  deriving (Show, Eq)
Now solving our problem is easy! We pattern match on the input TreeNode. For our first case, we just return Nil for Nil.
invertTree :: TreeNode -> TreeNode
invertTree Nil = Nil
invertTree (Node x left right) = ...
For our second case, we use the same Int value for the value of our result. Then we recursively call invertTree on the right child, but put this in the place of the left child for our new result node. Likewise, we recursively invert the left child of our original and use this result for the right of our result.
invertTree :: TreeNode -> TreeNode
invertTree Nil = Nil
invertTree (Node x left right) = Node x (invertTree right) (invertTree left)
Very easy!
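A nice property for sanity-checking this solution: inverting twice should give back the original tree. Here's a quick self-contained check, restating the type and function from above and encoding the example tree from the problem statement:

```haskell
data TreeNode = Nil | Node Int TreeNode TreeNode
  deriving (Show, Eq)

invertTree :: TreeNode -> TreeNode
invertTree Nil = Nil
invertTree (Node x left right) = Node x (invertTree right) (invertTree left)

-- The example tree from the problem statement.
sample :: TreeNode
sample = Node 45
  (Node 32 (Node 5 Nil Nil) (Node 40 (Node 37 Nil Nil) (Node 43 Nil Nil)))
  (Node 50 Nil (Node 100 Nil Nil))
```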
C++ Solution
In a non-functional language, it is still quite possible to solve this problem without recursion, but this is an occasion where we get very nice, clean code with recursion. As a rare treat, we’ll actually start with a C++ solution instead of jumping to Rust right away.
We would start by defining our TreeNode with a struct. Instead of relying on a separate Nil constructor, we use raw pointers for all our tree nodes. This means they can all potentially be nullptr.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    TreeNode(int v, TreeNode* l, TreeNode* r) : val(v), left(l), right(r) {}
};

TreeNode* invertTree(TreeNode* root) {
    ...
}
And our solution looks almost as easy as the Haskell solution:
TreeNode* invertTree(TreeNode* root) {
    if (root == nullptr) {
        return nullptr;
    }
    return new TreeNode(root->val, invertTree(root->right), invertTree(root->left));
}
Rust Solution
In Rust, it’s not quite as easy to work with recursive structures because of Rust’s memory system. In C++, we used raw pointers, which are fast but can cause significant problems if you aren’t careful (e.g. dereferencing null pointers, or memory leaks). Haskell uses garbage collected memory, which carries some runtime overhead but allows us to write simple code that won’t blow up in the weird ways C++ can.
Rust’s Memory System
Rust seeks to be fast like C++, while making it hard to do high-risk things like de-referencing a potentially null pointer, or leaking memory. It does this using the concept of “ownership”, and it’s a tricky concept to understand at first.
The ownership model makes it a bit harder for us to write basic recursive data structures. To write a basic binary tree, you’d have to answer questions like:
- Who “owns” the child nodes?
- Can I write a function that accesses the child nodes without taking ownership of them? What if I have to modify them?
- Can I copy a reference to a child node without copying the entire sub-structure?
- Can I create a “new” tree that references part of another tree without copying?
Writing a TreeNode
Here’s the TreeNode struct provided by LeetCode for solving this problem. We can see that references to the nodes themselves are held within 3(!) wrapper types:
#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

impl TreeNode {
    #[inline]
    pub fn new(val: i32) -> Self {
        TreeNode {
            val,
            left: None,
            right: None,
        }
    }
}

pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    ...
}
From inside to outside, here’s what the three wrappers mean:
- RefCell is a mutable, shareable container for data.
- Rc is a reference counting container. It automatically tracks how many references there are to the RefCell. The cell is de-allocated once this count reaches 0.
- Option is Rust’s equivalent of Maybe. This lets us use None for an empty tree.
Rust normally permits either a single mutable reference or multiple immutable references, checked at compile time. RefCell moves this borrow checking to runtime, which lets us mutate data that is shared through multiple Rc references. Let’s see how we can use these to write our invert_tree function.
Solving the Problem
We start by “cloning” the root input reference. For many types, clone means a deep copy, but in our case this doesn’t actually copy the entire tree! Because the node is wrapped in Rc, we’re just getting a new reference to the data in the RefCell. We then check if this is a Some wrapper. If it is None, we just return the root.
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    if let Some(node) = root.clone() {
        ...
    }
    return root;
}
If we didn’t “clone” root, the compiler would complain that we are “moving” the value in the condition, which would invalidate the prior reference to root.
Next, we use borrow_mut to get a mutable reference to the TreeNode inside the RefCell. This node_ref finally gives us something of type TreeNode so that we can work with the individual fields.
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    if let Some(node) = root.clone() {
        let mut node_ref = node.borrow_mut();
        ...
    }
    return root;
}
Now for node_ref, both left and right have the full wrapper type Option<Rc<RefCell<TreeNode>>>. We want to recursively call invert_tree on these. Once again though, we have to call clone before passing these to the recursive function.
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    if let Some(node) = root.clone() {
        let mut node_ref = node.borrow_mut();
        // Recursively invert left and right subtrees
        let left = invert_tree(node_ref.left.clone());
        let right = invert_tree(node_ref.right.clone());
        ...
    }
    return root;
}
Now because we have a mutable reference in node_ref, we can install these new results as its left and right subtrees!
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    if let Some(node) = root.clone() {
        let mut node_ref = node.borrow_mut();
        // Recursively invert left and right subtrees
        let left = invert_tree(node_ref.left.clone());
        let right = invert_tree(node_ref.right.clone());
        // Swap them
        node_ref.left = right;
        node_ref.right = left;
    }
    return root;
}
And now we’re done! We don’t need a separate return statement inside the if. We have modified node_ref, which is still a reference to the same data as root holds. So returning root returns our modified tree.
Conclusion
Even though this was a simple problem with a basic recursive algorithm, we saw how Rust presented some interesting difficulties in applying this algorithm. Languages all make different tradeoffs, so every language has some example where it is difficult to write code that is simple in other languages. For Rust, this is recursive data structures. For Haskell though, it’s things like mutable arrays.
If you want to get some serious practice with binary trees, you should sign up for our problem solving course, Solve.hs. In Module 2, you’ll actually get to implement a balanced tree set from scratch, which is a very interesting and challenging problem that will stretch your knowledge!
Spiral Matrix: Another Matrix Layer Problem
In last week’s article, we learned how to rotate a 2D Matrix in place using Haskell’s mutable array mechanics. This taught us how to think about a Matrix in terms of layers, starting from the outside and moving in towards the center.
Today, we’ll study one more 2D Matrix problem that uses this layer-by-layer paradigm. For more practice dealing with multi-dimensional arrays, check out our Solve.hs course! In Module 2, you’ll study all kinds of different data structures in Haskell, including 2D Matrices (both mutable and immutable).
The Problem
Today’s problem is Spiral Matrix. In this problem, we receive a 2D Matrix, and we would like to return the elements of that matrix in a 1D list in “spiral order”. This ordering starts from the top left and goes right. When we hit the top right corner, we move down to the bottom. Then we come back across the bottom row to the left, and back up the left column toward the top left. Then we continue this process on inner layers.
So, for example, let’s suppose we have this 4x4 matrix:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
This should return the following list:
[1,2,3,4,8,12,16,15,14,13,9,5,6,7,11,10]
At first glance, it seems like a lot of our layer-by-layer mechanics from last week will work again. All the numbers in the “first” layer come first, followed by the “second” layer, and so on. The trick though is that for this problem, we have to handle non-square matrices. So we can also have this matrix:
1 2 3 4
5 6 7 8
9 10 11 12
This should yield the list [1,2,3,4,8,12,11,10,9,5,6,7]. This isn’t a huge challenge, but we need a slightly different approach.
The Algorithm
We still want to generally move through the Matrix using a layer-by-layer approach. But instead of tracking the 4 corner points, we’ll just keep track of 4 “barriers”: imaginary lines dictating the “end” of each dimension (up/down/left/right) for us to scan. These barriers will be inclusive, meaning that they refer to the last valid row or column in that direction. We’ll call these “min row”, “min column”, “max row” and “max column”.
Now the general process for going through a layer will consist of 4 steps. Each step starts in a corner location and proceeds in one direction until the next corner is reached. Then, we can start again with the next layer.
The trick is the end condition. Because we can have rectangular matrices, the final layer can have a shape like 1 x n or n x 1, and this is a problem, because we wouldn’t need 4 steps. Even a square n x n matrix with odd n would have a 1x1 as its final layer, and this is also a problem, since it is unclear which “corner” this coordinate belongs to.
Thus we have to handle these edge cases. However, they are easy to both detect and resolve. We know we are in such a case when “min row” and “max row” are equal, or if “min column” and “max column” are equal. Then to resolve the case, we just do one pass instead of 4, including both endpoints.
Rust Solution
For our Rust solution, let’s start by defining important terms, like we always do. We’ll mainly be dealing with the 4 “barrier” values: the minimum and maximum row, and the minimum and maximum column. These are inclusive, so they are initially 0 and (length - 1). We also make a new vector to hold our result values.
pub fn spiral_order(matrix: Vec<Vec<i32>>) -> Vec<i32> {
    let mut result: Vec<i32> = Vec::new();
    let mut minR: usize = 0;
    let mut maxR: usize = matrix.len() - 1;
    let mut minC: usize = 0;
    let mut maxC: usize = matrix[0].len() - 1;
    ...
}
Now we want to write a while loop where each iteration processes a single layer. We’ll know we are out of layers if either “minimum” exceeds its corresponding “maximum”. Then we can start penciling in the different cases and phases of the loop. The edge cases occur when a minimum is exactly equal to its maximum. And for the normal case, we’ll do our 4-directional scanning.
pub fn spiral_order(matrix: Vec<Vec<i32>>) -> Vec<i32> {
    let mut result: Vec<i32> = Vec::new();
    let mut minR: usize = 0;
    let mut maxR: usize = matrix.len() - 1;
    let mut minC: usize = 0;
    let mut maxC: usize = matrix[0].len() - 1;
    while minR <= maxR && minC <= maxC {
        // Edge cases: single row or single column layers
        if minR == maxR {
            ...
            break;
        } else if minC == maxC {
            ...
            break;
        }
        // Scan TL->TR
        ...
        // Scan TR->BR
        ...
        // Scan BR->BL
        ...
        // Scan BL->TL
        ...
        minR += 1;
        minC += 1;
        maxR -= 1;
        maxC -= 1;
    }
    return result;
}
Our “loop update” step comes at the end, when we increase both minimums, and decrease both maximums. This shows we are shrinking to the next layer.
Now we just have to fill in each case. All of these are scans through some portion of the matrix. The only trick is getting the ranges correct for each scan.
We’ll start with the edge cases. For a single row or column scan, we just need one loop. This loop should be inclusive across its dimension. Rust has a similar range syntax to Haskell, but it is less flexible. We can make a range inclusive by using = before the end element.
pub fn spiral_order(matrix: Vec<Vec<i32>>) -> Vec<i32> {
    ...
    while minR <= maxR && minC <= maxC {
        // Edge cases: single row or single column layers
        if minR == maxR {
            for i in minC..=maxC {
                result.push(matrix[minR][i]);
            }
            break;
        } else if minC == maxC {
            for i in minR..=maxR {
                result.push(matrix[i][minC]);
            }
            break;
        }
        ...
    }
    return result;
}
Now let’s fill in the other cases. Again, getting the right ranges is the most important factor. We also have to make sure we don’t mix up our dimensions or directions! We go right along minR, down along maxC, left along maxR, and then up along minC.
To represent a decreasing range, we have to make the corresponding increasing range and then use .rev() to reverse it. This is a little inconvenient, giving us ranges that don’t look as nice, like for i in ((minC+1)..=maxC).rev(), because we want the decrementing range to include maxC but exclude minC.
pub fn spiral_order(matrix: Vec<Vec<i32>>) -> Vec<i32> {
    ...
    while minR <= maxR && minC <= maxC {
        ...
        // Scan TL->TR
        for i in minC..maxC {
            result.push(matrix[minR][i]);
        }
        // Scan TR->BR
        for i in minR..maxR {
            result.push(matrix[i][maxC]);
        }
        // Scan BR->BL
        for i in ((minC+1)..=maxC).rev() {
            result.push(matrix[maxR][i]);
        }
        // Scan BL->TL
        for i in ((minR+1)..=maxR).rev() {
            result.push(matrix[i][minC]);
        }
        minR += 1;
        minC += 1;
        maxR -= 1;
        maxC -= 1;
    }
    return result;
}
But once these cases are filled in, we’re done! Here’s the full solution:
pub fn spiral_order(matrix: Vec<Vec<i32>>) -> Vec<i32> {
    let mut result: Vec<i32> = Vec::new();
    let mut minR: usize = 0;
    let mut maxR: usize = matrix.len() - 1;
    let mut minC: usize = 0;
    let mut maxC: usize = matrix[0].len() - 1;
    while minR <= maxR && minC <= maxC {
        // Edge cases: single row or single column layers
        if minR == maxR {
            for i in minC..=maxC {
                result.push(matrix[minR][i]);
            }
            break;
        } else if minC == maxC {
            for i in minR..=maxR {
                result.push(matrix[i][minC]);
            }
            break;
        }
        // Scan TL->TR
        for i in minC..maxC {
            result.push(matrix[minR][i]);
        }
        // Scan TR->BR
        for i in minR..maxR {
            result.push(matrix[i][maxC]);
        }
        // Scan BR->BL
        for i in ((minC+1)..=maxC).rev() {
            result.push(matrix[maxR][i]);
        }
        // Scan BL->TL
        for i in ((minR+1)..=maxR).rev() {
            result.push(matrix[i][minC]);
        }
        minR += 1;
        minC += 1;
        maxR -= 1;
        maxC -= 1;
    }
    return result;
}
Haskell Solution
Now let’s write our Haskell solution. We don’t need any fancy mutation tricks here. Our function will just take a 2D array, and return a list of numbers.
spiralMatrix :: A.Array (Int, Int) Int -> [Int]
spiralMatrix arr = ...
  where
    ((minR', minC'), (maxR', maxC')) = A.bounds arr
Since we used a while loop in our Rust solution, it makes sense that we’ll want to use a raw recursive function that we’ll just call f. Our loop state was the 4 “barrier” values in each dimension. We’ll also use an accumulator value for our result. Since our barriers are inclusive, we can simply use the bounds of our array for the initial values.
spiralMatrix :: A.Array (Int, Int) Int -> [Int]
spiralMatrix arr = f minR' minC' maxR' maxC' []
  where
    ((minR', minC'), (maxR', maxC')) = A.bounds arr
    f :: Int -> Int -> Int -> Int -> [Int] -> [Int]
    f = undefined
This recursive function has 3 base cases. First, we have the “loop condition” we used in our Rust solution. If a min dimension value exceeds the max, we are done, and should return our accumulated result list.
Then the other two cases are our edge cases of having a single row or a single column for our final layer. In all these cases, we want to reverse the accumulated list. This means that when we put together our ranges, we want to be careful that they are in reverse order! So the edge cases should start at their max value and decrease to the min value (inclusive).
spiralMatrix :: A.Array (Int, Int) Int -> [Int]
spiralMatrix arr = f minR' minC' maxR' maxC' []
  where
    ((minR', minC'), (maxR', maxC')) = A.bounds arr
    f :: Int -> Int -> Int -> Int -> [Int] -> [Int]
    f minR minC maxR maxC acc
      | minR > maxR || minC > maxC = reverse acc
      | minR == maxR = reverse $ [arr A.! (minR, c) | c <- [maxC,maxC - 1..minC]] <> acc
      | minC == maxC = reverse $ [arr A.! (r, minC) | r <- [maxR,maxR - 1..minR]] <> acc
      | otherwise = ...
Now to fill in the otherwise case, we can do our 4 steps: going right from the top left, then going down from the top right, going left from the bottom right, and going up from the bottom left.
Like the edge cases, we make list comprehensions with ranges to pull the new numbers out of our input matrix. And again, we have to make sure we accumulate them in reverse order. Then we append all of them to the existing accumulation.
spiralMatrix :: A.Array (Int, Int) Int -> [Int]
spiralMatrix arr = f minR' minC' maxR' maxC' []
  where
    ((minR', minC'), (maxR', maxC')) = A.bounds arr
    f :: Int -> Int -> Int -> Int -> [Int] -> [Int]
    f minR minC maxR maxC acc
      ...
      | otherwise =
          let goRights = [arr A.! (minR, c) | c <- [maxC - 1, maxC - 2..minC]]
              goDowns = [arr A.! (r, maxC) | r <- [maxR - 1, maxR - 2..minR]]
              goLefts = [arr A.! (maxR, c) | c <- [minC + 1..maxC]]
              goUps = [arr A.! (r, minC) | r <- [minR + 1..maxR]]
              acc' = goUps <> goLefts <> goDowns <> goRights <> acc
          in f (minR + 1) (minC + 1) (maxR - 1) (maxC - 1) acc'
We conclude by making our recursive call with the updated result list, and shifting the barriers to get to the next layer.
Here’s the full implementation:
spiralMatrix :: A.Array (Int, Int) Int -> [Int]
spiralMatrix arr = f minR' minC' maxR' maxC' []
  where
    ((minR', minC'), (maxR', maxC')) = A.bounds arr
    f :: Int -> Int -> Int -> Int -> [Int] -> [Int]
    f minR minC maxR maxC acc
      | minR > maxR || minC > maxC = reverse acc
      | minR == maxR = reverse $ [arr A.! (minR, c) | c <- [maxC,maxC - 1..minC]] <> acc
      | minC == maxC = reverse $ [arr A.! (r, minC) | r <- [maxR,maxR - 1..minR]] <> acc
      | otherwise =
          let goRights = [arr A.! (minR, c) | c <- [maxC - 1, maxC - 2..minC]]
              goDowns = [arr A.! (r, maxC) | r <- [maxR - 1, maxR - 2..minR]]
              goLefts = [arr A.! (maxR, c) | c <- [minC + 1..maxC]]
              goUps = [arr A.! (r, minC) | r <- [minR + 1..maxR]]
              acc' = goUps <> goLefts <> goDowns <> goRights <> acc
          in f (minR + 1) (minC + 1) (maxR - 1) (maxC - 1) acc'
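To sanity-check the solution, we can build the 4x4 example as an array and confirm the spiral order. This snippet repeats the final function so it runs on its own (with Data.Array imported as A, as in the articles):

```haskell
import qualified Data.Array as A

spiralMatrix :: A.Array (Int, Int) Int -> [Int]
spiralMatrix arr = f minR' minC' maxR' maxC' []
  where
    ((minR', minC'), (maxR', maxC')) = A.bounds arr
    f :: Int -> Int -> Int -> Int -> [Int] -> [Int]
    f minR minC maxR maxC acc
      | minR > maxR || minC > maxC = reverse acc
      | minR == maxR = reverse $ [arr A.! (minR, c) | c <- [maxC,maxC - 1..minC]] <> acc
      | minC == maxC = reverse $ [arr A.! (r, minC) | r <- [maxR,maxR - 1..minR]] <> acc
      | otherwise =
          let goRights = [arr A.! (minR, c) | c <- [maxC - 1, maxC - 2..minC]]
              goDowns = [arr A.! (r, maxC) | r <- [maxR - 1, maxR - 2..minR]]
              goLefts = [arr A.! (maxR, c) | c <- [minC + 1..maxC]]
              goUps = [arr A.! (r, minC) | r <- [minR + 1..maxR]]
              acc' = goUps <> goLefts <> goDowns <> goRights <> acc
          in f (minR + 1) (minC + 1) (maxR - 1) (maxC - 1) acc'

-- The 4x4 example, stored row-by-row.
example4x4 :: A.Array (Int, Int) Int
example4x4 = A.listArray ((0, 0), (3, 3)) [1..16]
```

Evaluating spiralMatrix example4x4 should give back the spiral order we worked out by hand above.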
Conclusion
This is the last matrix-based problem we’ll study for now. Next time we’ll start considering some tree-based problems. If you sign up for our Solve.hs course, you’ll learn about both of these kinds of data structures in Module 2. You’ll implement a tree set from scratch, and you’ll get lots of practice working with these and many other structures. So enroll today!
Image Rotation: Mutable Arrays in Haskell
In last week’s article, we took our first step into working with multi-dimensional arrays. Today, we’ll be working with another Matrix problem that involves in-place mutation. The Haskell solution uses the MArray interface, which takes us out of our usual pure, immutable style.
The MArray interface is a little tricky to work with. If you want a full overview of the API, you should sign up for our Solve.hs course, where we cover mutable arrays in module 2!
The Problem
Today’s problem is Rotate Image. We’re going to take a 2D Matrix of integer values as our input and rotate the matrix 90 degrees clockwise. We must accomplish this in place, modifying the input value without allocating a new Matrix. The input matrix is always “square” (n x n).
Here are a few examples to illustrate the idea. We can start with a 2x2 matrix:
1 2 | 3 1
3 4 | 4 2
The 4x4 rotation makes it clearer that we’re not just moving numbers one space over. Each corner element will go to a new corner. You can also see how the inside of the matrix is also rotating:
1 2 3 4 | 13 9 5 1
5 6 7 8 | 14 10 6 2
9 10 11 12 | 15 11 7 3
13 14 15 16 | 16 12 8 4
The 3x3 version shows how, with an odd number of rows and columns, the innermost number stays in place.
1 2 3 | 7 4 1
4 5 6 | 8 5 2
7 8 9 | 9 6 3
The Algorithm
While this problem might be a little intimidating at first, we just have to break it into sufficiently small and repeatable pieces. The core step is that we swap four numbers into each other’s positions. It’s easy to see, for example, that the four corners always trade places with one another (1, 4, 13, 16 in the 4x4 example).
What’s important is seeing the other sets of 4. We move clockwise to get the next 4 values:
- The value to the right of the top left corner
- The value below the top right corner
- The value to the left of the bottom right corner
- The value above the bottom left corner.
So in the 4x4 example, these would be 2, 8, 15, 9. Then another group is 3, 12, 14, 5.
Those 3 groups are all the rotations we need for the “outer layer”. Then we move to the next layer, where we have a single group of 4: 6, 7, 10, 11.
This should tell us that we have a 3-step process:
- Loop through each layer of the matrix
- Identify all groups of 4 in this layer
- Rotate each group of 4
It helps to put a count on the size of each of these loops. For an n x n matrix, the number of layers to rotate is n / 2, rounded down, because the inner-most layer needs no rotation in an odd-sized matrix.
Then for a layer spanning from column c1 to c2, the number of groups in that layer is just c2 - c1. So for the first layer in a 4x4, we span columns 0 to 3, and there are 3 groups of 4. In the inner layer, we span columns 1 to 2, so there is only 1 group of 4.
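The count above can be captured in a one-line formula. This is just an illustrative sketch (groupsInLayer is a hypothetical helper, not part of the final solution): layer l of an n x n matrix spans columns l through n - 1 - l, so the group count is the column difference.

```haskell
-- Number of 4-element groups to rotate in layer l of an n x n matrix.
-- Layer l spans columns l .. n - 1 - l, so the count is the column difference.
groupsInLayer :: Int -> Int -> Int
groupsInLayer n l = (n - 1 - l) - l
```

For a 4x4, the outer layer (l = 0) has 3 groups and the inner layer (l = 1) has 1, matching the counts above.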
Rust Solution
As is typical, we’ll see more of a loop structure in our Rust code, and a recursive version of this solution in Haskell. We’ll also start by defining various terms we’ll use. There are multiple ways to approach the details of this problem, but we’ll take an approach that maximizes the clarity of our inner loops.
We’ll define each “layer” using the four corner coordinates of that layer. So for an n x n matrix, these are (0,0), (0, n - 1), (n - 1, n - 1), (n - 1, 0). After we finish looping through a layer, we can simply increment/decrement each of these values as appropriate to get the corner coordinates of the next layer ((1,1), (1, n - 2), etc.).
So let’s start our solution by defining the 8 mutable values for these 4 corners. Each corner (top/left/bottom/right) has a row R and column C value.
pub fn rotate(matrix: &mut Vec<Vec<i32>>) {
    let n = matrix.len();
    let numLayers = n / 2;
    let mut topLeftR = 0;
    let mut topLeftC = 0;
    let mut topRightR = 0;
    let mut topRightC = n - 1;
    let mut bottomRightR = n - 1;
    let mut bottomRightC = n - 1;
    let mut bottomLeftR = n - 1;
    let mut bottomLeftC = 0;
    ...
}
It would be possible to solve the problem without these values, determining coordinates using the layer number. But I’ve found this to be somewhat more error prone, since we’re constantly adding and subtracting from different coordinates in different combinations. We get the number of layers from n / 2.
Now let’s frame the outer loop. We conclude the loop by modifying each coordinate point. Then at the beginning of the loop, we can determine the number of “groups” for the layer by taking the difference between the left and right column coordinates.
pub fn rotate(matrix: &mut Vec<Vec<i32>>) {
    ...
    for _ in 0..numLayers {
        let numGroups = topRightC - topLeftC;
        for j in 0..numGroups {
            ...
        }
        topLeftR += 1;
        topLeftC += 1;
        topRightR += 1;
        topRightC -= 1;
        bottomRightR -= 1;
        bottomRightC -= 1;
        bottomLeftR -= 1;
        bottomLeftC += 1;
    }
}
Now we just need the logic for rotating a single group of 4 points. This is a 5-step process:
- Save top left value as
temp - Move bottom left to top left
- Move bottom right to bottom left
- Move top right to bottom right
- Move
temp(original top left) to top right
Unlike the layer number, we’ll use the group variable j for arithmetic here. When you’re writing this yourself, it’s important to go slowly to make sure you’re using the right corner values and adding/subtracting j from the correct dimension.
pub fn rotate(matrix: &mut Vec<Vec<i32>>) {
    ...
    for _ in 0..numLayers {
        let numGroups = topRightC - topLeftC;
        for j in 0..numGroups {
            let temp = matrix[topLeftR][topLeftC + j];
            matrix[topLeftR][topLeftC + j] = matrix[bottomLeftR - j][bottomLeftC];
            matrix[bottomLeftR - j][bottomLeftC] = matrix[bottomRightR][bottomRightC - j];
            matrix[bottomRightR][bottomRightC - j] = matrix[topRightR + j][topRightC];
            matrix[topRightR + j][topRightC] = temp;
        }
        ... // (update corners)
    }
}
And then we’re done! We don’t actually need to return a value since we’re just modifying the input in place. Here’s the full solution:
pub fn rotate(matrix: &mut Vec<Vec<i32>>) {
    let n = matrix.len();
    let numLayers = n / 2;
    let mut topLeftR = 0;
    let mut topLeftC = 0;
    let mut topRightR = 0;
    let mut topRightC = n - 1;
    let mut bottomRightR = n - 1;
    let mut bottomRightC = n - 1;
    let mut bottomLeftR = n - 1;
    let mut bottomLeftC = 0;
    for _ in 0..numLayers {
        let numGroups = topRightC - topLeftC;
        for j in 0..numGroups {
            let temp = matrix[topLeftR][topLeftC + j];
            matrix[topLeftR][topLeftC + j] = matrix[bottomLeftR - j][bottomLeftC];
            matrix[bottomLeftR - j][bottomLeftC] = matrix[bottomRightR][bottomRightC - j];
            matrix[bottomRightR][bottomRightC - j] = matrix[topRightR + j][topRightC];
            matrix[topRightR + j][topRightC] = temp;
        }
        topLeftR += 1;
        topLeftC += 1;
        topRightR += 1;
        topRightC -= 1;
        bottomRightR -= 1;
        bottomRightC -= 1;
        bottomLeftR -= 1;
        bottomLeftC += 1;
    }
}
Haskell Solution
This is an interesting problem to solve in Haskell because Haskell is a generally immutable language. Unlike Rust, we can’t make values mutable just by putting the keyword mut in front of them.
With arrays, though, we can modify them in place using the MArray type class. We won’t go through all the details of the interface in this article (you can learn about all that in Solve.hs Module 2). But we’ll start with the type signature:
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
This tells us we are taking a mutable array, where the array type is polymorphic but tied to the monad m. For example, IOArray would work with the IO monad. We don’t return anything, because we’re modifying our input.
We still begin our function by defining terms, but now we need to use monadic actions even to retrieve the bounds of our array.
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
rotateImage arr = do
  ((minR, minC), (maxR, maxC)) <- getBounds arr
  let n = maxR - minR + 1
  let numLayers = n `quot` 2
  ...
Our algorithm has two loop levels. The outer loop goes through the different layers of the matrix. The inner loop goes through each group of 4 within the layer. In Haskell, both of these loops are recursive, monadic functions. Our Rust loops treat the four corner points of the layer as stateful values, so these need to be inputs to our recursive functions. In addition, each function will take the layer/group number as an input.
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
rotateImage arr = do
  ((minR, minC), (maxR, maxC)) <- getBounds arr
  let n = maxR - minR + 1
  let numLayers = n `quot` 2
  ...
  where
    rotateLayer tl@(tlR, tlC) tr@(trR, trC) br@(brR, brC) bl@(blR, blC) n = ...
    rotateGroup (tlR, tlC) (trR, trC) (brR, brC) (blR, blC) j = ...
Now we just have to fill in these functions. For rotateLayer, we use the “layer number” parameter as a countdown. Once it reaches 0, we’ll be done. We just need to determine the number of groups in this layer using the column difference of left and right. Then we’ll call rotateGroup for each group.
We make the first call to rotateLayer with numLayers and the original corners, coming from our dimensions. When we recurse, we add/subtract 1 from the corner dimensions, and subtract 1 from the layer number.
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
rotateImage arr = do
  ((minR, minC), (maxR, maxC)) <- getBounds arr
  let n = maxR - minR + 1
  let numLayers = n `quot` 2
  rotateLayer (minR, minC) (minR, maxC) (maxR, maxC) (maxR, minC) numLayers
  where
    rotateLayer _ _ _ _ 0 = return ()
    rotateLayer tl@(tlR, tlC) tr@(trR, trC) br@(brR, brC) bl@(blR, blC) n = do
      let numGroups = ([0..(trC - tlC - 1)] :: [Int])
      forM_ numGroups (rotateGroup tl tr br bl)
      rotateLayer (tlR + 1, tlC + 1) (trR + 1, trC - 1) (brR - 1, brC - 1) (blR - 1, blC + 1) (n - 1)
    rotateGroup (tlR, tlC) (trR, trC) (brR, brC) (blR, blC) j = ...
And how do we rotate a group? We use the same five steps we took in Rust. We save the top left as temp and then move the values around. We use the monadic functions readArray and writeArray to perform these actions in place on our Matrix.
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
rotateImage arr = do
  ...
  where
    ...
    rotateGroup (tlR, tlC) (trR, trC) (brR, brC) (blR, blC) j = do
      temp <- readArray arr (tlR, tlC + j)
      readArray arr (blR - j, blC) >>= writeArray arr (tlR, tlC + j)
      readArray arr (brR, brC - j) >>= writeArray arr (blR - j, blC)
      readArray arr (trR + j, trC) >>= writeArray arr (brR, brC - j)
      writeArray arr (trR + j, trC) temp
Here’s the full implementation:
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
rotateImage arr = do
  ((minR, minC), (maxR, maxC)) <- getBounds arr
  let n = maxR - minR + 1
  let numLayers = n `quot` 2
  rotateLayer (minR, minC) (minR, maxC) (maxR, maxC) (maxR, minC) numLayers
  where
    rotateLayer _ _ _ _ 0 = return ()
    rotateLayer tl@(tlR, tlC) tr@(trR, trC) br@(brR, brC) bl@(blR, blC) n = do
      let numGroups = ([0..(trC - tlC - 1)] :: [Int])
      forM_ numGroups (rotateGroup tl tr br bl)
      rotateLayer (tlR + 1, tlC + 1) (trR + 1, trC - 1) (brR - 1, brC - 1) (blR - 1, blC + 1) (n - 1)
    rotateGroup (tlR, tlC) (trR, trC) (brR, brC) (blR, blC) j = do
      temp <- readArray arr (tlR, tlC + j)
      readArray arr (blR - j, blC) >>= writeArray arr (tlR, tlC + j)
      readArray arr (brR, brC - j) >>= writeArray arr (blR - j, blC)
      readArray arr (trR + j, trC) >>= writeArray arr (brR, brC - j)
      writeArray arr (trR + j, trC) temp
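To see the mutable solution in action, we can run it inside the ST monad on the 3x3 example and get a pure rotated array back with runSTArray. This snippet repeats the function so it runs on its own; note that the polymorphic signature needs the FlexibleContexts extension:

```haskell
{-# LANGUAGE FlexibleContexts #-}

import Control.Monad (forM_)
import Data.Array (Array, elems)
import Data.Array.MArray (MArray, getBounds, newListArray, readArray, writeArray)
import Data.Array.ST (runSTArray)

rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
rotateImage arr = do
  ((minR, minC), (maxR, maxC)) <- getBounds arr
  let n = maxR - minR + 1
  let numLayers = n `quot` 2
  rotateLayer (minR, minC) (minR, maxC) (maxR, maxC) (maxR, minC) numLayers
  where
    rotateLayer _ _ _ _ 0 = return ()
    rotateLayer tl@(tlR, tlC) tr@(trR, trC) br@(brR, brC) bl@(blR, blC) n = do
      let numGroups = ([0..(trC - tlC - 1)] :: [Int])
      forM_ numGroups (rotateGroup tl tr br bl)
      rotateLayer (tlR + 1, tlC + 1) (trR + 1, trC - 1) (brR - 1, brC - 1) (blR - 1, blC + 1) (n - 1)
    rotateGroup (tlR, tlC) (trR, trC) (brR, brC) (blR, blC) j = do
      temp <- readArray arr (tlR, tlC + j)
      readArray arr (blR - j, blC) >>= writeArray arr (tlR, tlC + j)
      readArray arr (brR, brC - j) >>= writeArray arr (blR - j, blC)
      readArray arr (trR + j, trC) >>= writeArray arr (brR, brC - j)
      writeArray arr (trR + j, trC) temp

-- Rotate the 3x3 example purely, using an STArray under the hood.
rotated3x3 :: Array (Int, Int) Int
rotated3x3 = runSTArray $ do
  arr <- newListArray ((0, 0), (2, 2)) [1 .. 9]
  rotateImage arr
  return arr
```

The elements of rotated3x3, read row by row, should match the rotated 3x3 matrix from the examples above.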
Conclusion
We’ve got one more Matrix problem to solve next time, and then we’ll move on to some other data structures. To learn more about using Data Structures and Algorithms in Haskell, you can take our Solve.hs course. You’ll get the chance to write a number of data structures from scratch, and you’ll get plenty of practice working with them and using them in algorithms!
Binary Search in a 2D Matrix
In our problem last week, we covered a complex problem that used a binary search. Today, we’ll apply binary search again to solidify our understanding of it. This time, instead of extra algorithmic complexity, we’ll start adding some data structure complexity. We’ll be working with a 2D Matrix instead of basic arrays.
To learn more about data structures and algorithms in Haskell, you should take a look at our Solve.hs course! In particular, you’ll cover multi-dimensional arrays in module 2, and you’ll learn how to write algorithms in Haskell in module 3!
The Problem
Today’s problem is Search a 2D Matrix, and the description is straightforward. We’re given a 2D m x n matrix, as well as a target number. We have to return a boolean for whether or not that number is in the Matrix.
This is trivial with a simple scan, but we have an additional constraint that lets us solve the problem faster. The matrix is essentially ordered. Each row is non-decreasing, and the first element of each successive row is no smaller than the last element of the preceding row.
This allows us to get a solution that runs in O(log m + log n) time (equivalently, O(log(m * n))), a considerable improvement over a linear scan.
The Algorithm
The algorithm is simple as well. We’ll do two binary searches. First, we’ll search over the rows to identify a row that could contain the element. Then we’ll do a binary search of that row to see if the element is present or not.
We’ll have a slightly different form to our searches compared to last time. In last week’s problem, we knew we had to find a valid index for our search. Now, we may find that no valid index exists.
So we’ll structure our search interval in a semi-open fashion. The first index in our search interval is inclusive, meaning that it could still be a valid index. The second index is exclusive, meaning it is the lowest index that we consider invalid.
In mathematical notation, we would represent such an interval with a square bracket on the left and a parenthesis on the right. So if that interval is [0, 4), then 0, 1, 2, 3 are valid values. The interval [2,2) would be considered empty, with no valid values. We’ll see how we apply this idea in practice.
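To make the semi-open convention concrete before we apply it, here is a minimal, generic binary search in Rust. The `lower_bound` helper is hypothetical and not part of today’s solution; it returns the first index whose value is at least the target, or the array length if no such index exists:

```rust
// Hypothetical helper illustrating the semi-open [low, hi) convention.
// Returns the first index i with v[i] >= target, or v.len() if none exists.
fn lower_bound(v: &[i32], target: i32) -> usize {
    let mut low = 0; // inclusive: could still be a valid answer
    let mut hi = v.len(); // exclusive: the lowest "invalid" index
    while low < hi {
        let mid = (low + hi) / 2;
        if v[mid] < target {
            low = mid + 1; // mid is ruled out; low stays potentially valid
        } else {
            hi = mid; // mid could be the answer; shrink the interval
        }
    }
    low // low == hi: the first index not ruled out (possibly v.len())
}

fn main() {
    assert_eq!(lower_bound(&[1, 3, 5, 7], 4), 2);
    assert_eq!(lower_bound(&[1, 3, 5, 7], 8), 4); // interval shrinks to empty
    println!("ok");
}
```

Notice that when no valid index exists, the interval shrinks to an empty `[n, n)` and we return `n`, an easily testable "invalid" value. We’ll use the same trick below.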
Rust Solution
We don’t have many terms to define at the start of this solution. We’ll save the size of both dimensions, and then prepare for the first binary search by assigning low as 0 (the first potentially “valid” answer), hi as m (the lowest “invalid” answer), and creating our output rowWithTarget value. We initialize it to m as well: if we fail to re-assign rowWithTarget during the binary search, we want it to hold an easily testable invalid value.
pub fn search_matrix(matrix: Vec<Vec<i32>>, target: i32) -> bool {
let m = matrix.len();
let n = matrix[0].len();
let mut low = 0;
let mut hi = m;
let mut rowWithTarget = m;
...
}
Now we write our first binary search, looking for a row that could contain our target value. We maintain the typical pattern of binary search, using the loop while (low < hi) and assigning mid = (low + hi) / 2.
pub fn search_matrix(matrix: Vec<Vec<i32>>, target: i32) -> bool {
...
while (low < hi) {
let mid: usize = (low + hi) / 2;
if (matrix[mid][0] > target) {
hi = mid;
} else if (matrix[mid][n - 1] < target) {
low = mid + 1;
} else {
rowWithTarget = mid;
break;
}
}
if (rowWithTarget >= m) {
return false;
}
...
}
If the first element of the row is too large, we know that mid is “invalid”, so we can assign it as hi and continue. If the last element is too small, then we reassign low as mid + 1, as we want low to still be a potentially valid value.
Otherwise, we have found a potential row, so we assign rowWithTarget and break. If, after this search, rowWithTarget has the “invalid” value of m, we can return false, as there are no valid values.
Now we just do the same thing over again, but within rowWithTarget! We reassign low and hi (as n this time) to reset the while loop. And now our comparisons will look at the specific value matrix[rowWithTarget][mid].
pub fn search_matrix(matrix: Vec<Vec<i32>>, target: i32) -> bool {
...
low = 0;
hi = n;
while (low < hi) {
let mid: usize = (low + hi) / 2;
if (matrix[rowWithTarget][mid] > target) {
hi = mid;
} else if (matrix[rowWithTarget][mid] < target) {
low = mid + 1;
} else {
return true;
}
}
return false;
}
Again, we follow the same pattern of re-assigning low and hi. If we don’t hit the return true case in the loop, we’ll end up with return false at the end, because we haven’t found the target.
Here’s the full solution:
pub fn search_matrix(matrix: Vec<Vec<i32>>, target: i32) -> bool {
let m = matrix.len();
let n = matrix[0].len();
let mut low = 0;
let mut hi = m;
let mut rowWithTarget = m;
while (low < hi) {
let mid: usize = (low + hi) / 2;
if (matrix[mid][0] > target) {
hi = mid;
} else if (matrix[mid][n - 1] < target) {
low = mid + 1;
} else {
rowWithTarget = mid;
break;
}
}
if (rowWithTarget >= m) {
return false;
}
low = 0;
hi = n;
while (low < hi) {
let mid: usize = (low + hi) / 2;
if (matrix[rowWithTarget][mid] > target) {
hi = mid;
} else if (matrix[rowWithTarget][mid] < target) {
low = mid + 1;
} else {
return true;
}
}
return false;
}
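As a quick sanity check, we can exercise the solution as a self-contained program. This copy is lightly adapted to Rust conventions (snake_case `row_with_target`, no parentheses around conditions), and the sample matrix in `main` is just an illustrative input:

```rust
pub fn search_matrix(matrix: Vec<Vec<i32>>, target: i32) -> bool {
    let m = matrix.len();
    let n = matrix[0].len();
    let mut low = 0;
    let mut hi = m;
    let mut row_with_target = m; // m is our easily testable "invalid" value
    // First binary search: find a row whose range could contain the target.
    while low < hi {
        let mid = (low + hi) / 2;
        if matrix[mid][0] > target {
            hi = mid;
        } else if matrix[mid][n - 1] < target {
            low = mid + 1;
        } else {
            row_with_target = mid;
            break;
        }
    }
    if row_with_target >= m {
        return false;
    }
    // Second binary search: look for the target within that row.
    low = 0;
    hi = n;
    while low < hi {
        let mid = (low + hi) / 2;
        if matrix[row_with_target][mid] > target {
            hi = mid;
        } else if matrix[row_with_target][mid] < target {
            low = mid + 1;
        } else {
            return true;
        }
    }
    false
}

fn main() {
    let matrix = vec![vec![1, 3, 5, 7], vec![10, 11, 16, 20], vec![23, 30, 34, 60]];
    assert!(search_matrix(matrix.clone(), 3));
    assert!(!search_matrix(matrix, 13));
    println!("ok");
}
```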
Haskell Solution
In our Haskell solution, the main difference of course will be using recursion for the binary search. However, we’ll also change up the data structure a bit. In the Rust framing of the problem, we had a vector of vectors of values. We could do this in Haskell, but we could also use Array (Int, Int) Int. This lets us map row/column pairs to numbers in a more intuitive way.
import qualified Data.Array as A
search2DMatrix :: A.Array (Int, Int) Int -> Int -> Bool
search2DMatrix matrix target = ...
where
((minR, minC), (maxR, maxC)) = A.bounds matrix
Another unique feature of arrays is that the bounds don’t have to start from 0. We can have totally custom bounding dimensions for our rows and columns. So instead of using m and n, we’ll need to use the min and max of the row and column dimensions.
So now let’s define our first binary search, looking for the valid row. As we did last week, the input to our function will be two Int values, for the low and hi. As in our Rust solution we’ll access the first and last element of the row defined by the “middle” of low and hi, and compare them against the target. We make recursive calls to searchRow if the row isn’t valid.
search2DMatrix :: A.Array (Int, Int) Int -> Int -> Bool
search2DMatrix matrix target = result
where
((minR, minC), (maxR, maxC)) = A.bounds matrix
searchRow :: (Int, Int) -> Int
searchRow (low, hi) = if low >= hi then maxR + 1 else
let mid = (low + hi) `quot` 2
firstInRow = matrix A.! (mid, minC)
lastInRow = matrix A.! (mid, maxC)
in if firstInRow > target
then searchRow (low, mid)
else if lastInRow < target
then searchRow (mid + 1, hi)
else mid
rowWithTarget = searchRow (minR, maxR + 1)
result = rowWithTarget <= maxR && ...
Instead of m, we have maxR + 1, which we use as the initial hi value, as well as the return value in the base case where low meets hi. We can return a result of False if rowWithTarget does not come back with a value of at most maxR.
Now for our second search, we follow the same pattern, but now we’re returning a boolean. The base case returns False, and we return True if we find the value in rowWithTarget at position mid. Here’s what that looks like:
search2DMatrix :: A.Array (Int, Int) Int -> Int -> Bool
search2DMatrix matrix target = result
where
...
rowWithTarget = searchRow (minR, maxR + 1)
searchCol :: (Int, Int) -> Bool
searchCol (low, hi) = low < hi &&
let mid = (low + hi) `quot` 2
val = matrix A.! (rowWithTarget, mid)
in if val > target
then searchCol (low, mid)
else if val < target
then searchCol (mid + 1, hi)
else True
result = rowWithTarget <= maxR && searchCol (minC, maxC + 1)
You’ll see we now use the outcome of searchCol for result. And this completes our solution! Here’s the full code:
search2DMatrix :: A.Array (Int, Int) Int -> Int -> Bool
search2DMatrix matrix target = result
where
((minR, minC), (maxR, maxC)) = A.bounds matrix
searchRow :: (Int, Int) -> Int
searchRow (low, hi) = if low >= hi then maxR + 1 else
let mid = (low + hi) `quot` 2
firstInRow = matrix A.! (mid, minC)
lastInRow = matrix A.! (mid, maxC)
in if firstInRow > target
then searchRow (low, mid)
else if lastInRow < target
then searchRow (mid + 1, hi)
else mid
rowWithTarget = searchRow (minR, maxR + 1)
searchCol :: (Int, Int) -> Bool
searchCol (low, hi) = low < hi &&
let mid = (low + hi) `quot` 2
val = matrix A.! (rowWithTarget, mid)
in if val > target
then searchCol (low, mid)
else if val < target
then searchCol (mid + 1, hi)
else True
result = rowWithTarget <= maxR && searchCol (minC, maxC + 1)
Conclusion
Next week, we’ll stay on the subject of 2D matrices, but we’ll learn about array mutation. This is a very tricky subject in Haskell, so make sure to come back for that article!
To learn how these data structures work in Haskell, read about Solve.hs, our Haskell Data Structures & Algorithms course!
Binary Search in Haskell and Rust
This week we’ll be continuing our series of problem solving in Haskell and Rust. But now we’re going to start moving beyond the terrain of “basic” problem solving techniques with strings, lists and arrays, and start moving in the direction of more complicated data structures and algorithms. Today we’ll explore a problem that is still array-based, but uses a tricky algorithm that involves binary search!
You’ll learn more about Data Structures and Algorithms in our Solve.hs course! The last 7 weeks or so of blog articles have focused on the types of problems you’ll see in Module 1 of that course, but now we’re going to start encountering ideas from Modules 2 & 3, which look extensively at essential data structures and algorithms you need to know for problem solving.
The Problem
Today’s problem is Median of Two Sorted Arrays. In this problem, we receive two arrays of numbers as input, each of them in sorted order. The arrays are not necessarily of the same size. Our job is to find the median of the combined set of numbers.
Now there’s a conceptually easy approach to this. We could simply scan through the two arrays, keeping track of one index for each one. We would increase the index for whichever number is currently smaller, and stop once we have passed by half of the total numbers. This approach is essentially the “merge” part of merge sort, and it would take O(n) time, since we are scanning half of all the numbers.
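This merge-style scan is worth having in code, if only as a reference to test the faster solution against. Here’s a sketch in Rust; the `brute_force_median` helper is hypothetical and not part of the article’s solutions:

```rust
// Baseline: walk both sorted arrays with one index each, as in the "merge"
// step of merge sort, stopping once we reach the middle of the combined
// sequence. Runs in O(n + m) time.
fn brute_force_median(a: &[i32], b: &[i32]) -> f64 {
    let t = a.len() + b.len();
    let (mut i, mut j) = (0, 0);
    let (mut prev, mut curr) = (0, 0);
    // After t / 2 + 1 picks, `curr` holds index t / 2 of the merged
    // sequence and `prev` holds index t / 2 - 1.
    for _ in 0..(t / 2 + 1) {
        prev = curr;
        curr = if i < a.len() && (j >= b.len() || a[i] <= b[j]) {
            i += 1;
            a[i - 1]
        } else {
            j += 1;
            b[j - 1]
        };
    }
    if t % 2 == 1 {
        curr as f64
    } else {
        (prev + curr) as f64 / 2.0
    }
}

fn main() {
    assert_eq!(brute_force_median(&[1, 3], &[2]), 2.0);
    assert_eq!(brute_force_median(&[1, 2], &[3, 4]), 2.5);
    println!("ok");
}
```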
However, there’s a faster approach! And if you are asked this question in an interview for anything other than a very junior position, your interviewer will expect you to find this faster approach. Because the arrays are sorted, we can leverage binary search to find the median in O(log n) time. The approach isn’t easy to see though! Let’s go over the algorithm before we get into any code.
The Algorithm
This algorithm is a little tricky to follow (this problem is rated as “hard” on LeetCode). So we’re going to treat this a bit like a mathematical proof, and begin by defining useful terms. Then it will be easy to describe the coding concepts behind the algorithm.
Defining our Terms
Our input consists of 2 arrays, arr1 and arr2 with potentially different sizes n and m, respectively. Without loss of generality, let arr1 be the “shorter” array, so that n <= m. We’ll also define t as the total number of elements, n + m.
It is worthwhile to note right off the bat that if t is odd, then a single element from one of the two lists will be the median. If t is even, then we will average two elements together. Even though we won’t actually create the final merged array, we can imagine that it consists of 3 parts:
- The “prior” portion - all numbers before the median element(s)
- The median element(s), either 1 or 2.
- The “latter” portion - all numbers after the median element(s)
The total number of elements in the “prior” portion will end up being (t - 1) / 2, bearing in mind how integer division works. For example, whether t is 15 or 16, we get 7 elements in the “prior” portion. We’ll use p for this number.
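We can sanity-check that integer-division claim directly. The `prior_count` helper below is a throwaway name, just for illustration:

```rust
// Number of elements that land in the "prior" portion for t total elements.
fn prior_count(t: usize) -> usize {
    (t - 1) / 2
}

fn main() {
    assert_eq!(prior_count(15), 7); // odd total: 7 elements before the single median
    assert_eq!(prior_count(16), 7); // even total: 7 before the two middle elements
    println!("ok");
}
```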
Finally, let’s imagine p1, the number of elements from arr1 that will end up in the prior portion. If we know p1, then p2, the number of elements from arr2 in the prior portion, is fixed, because p1 + p2 = p. We can then think of p1 as an index into arr1: the index of the first element that is not in the prior portion. The only trick is that this index could be n, indicating that all elements of arr1 are in the prior portion.
Getting the Final Answer from our Terms
If we have the “correct” values for p1 and p2, then finding the median is easy. If t is odd, then the lower number between arr1[p1] and arr2[p2] is the median. If t is even, then we average the two smallest numbers among (arr1[p1], arr2[p2], arr1[p1 + 1], arr2[p2 + 1]).
So we’ve reduced this problem to a matter of finding p1, since p2 can be easily derived from it. How do we know we have the “correct” value for p1, and how do we search for it efficiently?
Solving for p1
The answer is that we will conduct a binary search on arr1 in order to find the correct value of p1. For any particular choice of p1, we determine the corresponding value of p2. Then we make two comparisons:
- Compare arr1[p1 - 1] to arr2[p2]
- Compare arr2[p2 - 1] to arr1[p1]
If both comparisons are less-than-or-equals, then our two p values are correct! The slices arr1[0..p1-1] and arr2[0..p2-1] always constitute a total of p values, and if these values are all no larger than arr1[p1] and arr2[p2], then they constitute the entire “prior” set.
If, on the other hand, the first comparison yields “greater than”, then we have too many values for arr1 in our prior set. This means we need to recursively do the binary search on the left side of arr1, since p1 should be smaller.
Then if the second comparison yields “greater than”, we have too few values from arr1 in the “prior” set. We should increase p1 by searching the right half of our array.
This provides a complete algorithm for us to follow!
Rust Implementation
Our algorithm description was quite long, but the advantage of having so many details is that the code starts to write itself! We’ll start with our Rust implementation. Stage 1 is to define all of the terms using our input values. We want to define our sizes and array references generically so that arr1 is the shorter array:
pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
let mut n = nums1.len();
let mut m = nums2.len();
let mut arr1: &Vec<i32> = &nums1;
let mut arr2: &Vec<i32> = &nums2;
if (m < n) {
n = nums2.len();
m = nums1.len();
arr1 = &nums2;
arr2 = &nums1;
}
let t = n + m;
let p: usize = (t - 1) / 2;
...
}
Anatomy of a Binary Search
The next stage is the binary search, so we can find p1 and p2. Now a binary search is a particular kind of loop pattern. Like many of the loop patterns we worked with in the previous weeks, we can express it recursively, or with a loop construct like for or while. We’ll start with a while loop solution for Rust, and then show the recursive solution with Haskell.
All loops maintain some kind of state. For a binary search, the primary state is the two endpoints representing our “interval of interest”. This starts out as the entire interval, and shrinks by half each time until we’ve narrowed to a single element (or no elements). We’ll represent these interval endpoints with low and hi. Our loop concludes once low is as large as hi.
let mut low = 0;
// Use the shorter array size!
let mut hi = n;
while (low < hi) {
...
}
In our particular case, we are also trying to determine the values for p1 and p2. Each time we specify an interval, we’ll see if the midpoint of that interval (between low and hi) is the correct value of p1:
...
let mut low = 0;
let mut hi = n;
let mut p1 = 0;
let mut p2 = 0;
while (low < hi) {
p1 = (low + hi) / 2;
p2 = p - p1;
...
}
Now we evaluate this p1 value using the two conditions we specified in our algorithm. These are self-explanatory, except we do need to cover some edge cases where one of our values is at the edge of the array bounds.
For example, if p1 is 0, the first condition is automatically true. A false first condition would mean we want fewer elements from arr1, which is impossible when p1 is already 0.
...
let mut low = 0;
let mut hi = n;
let mut p1 = 0;
let mut p2 = 0;
while (low < hi) {
p1 = (low + hi) / 2;
p2 = p - p1;
let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
let cond2 = p1 == n || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
if (cond1 && cond2) {
break;
} else if (!cond1) {
p1 -= 1;
hi = p1;
} else {
p1 += 1;
low = p1;
}
}
p2 = p - p1;
...
If both conditions are met, you’ll see we break, because we’ve found the right value for p1! Otherwise, we know p1 is invalid. This means we want to exclude the existing p1 value from further consideration by changing either low or hi to remove it from the interval of interest.
So if cond1 is false, hi becomes p1 - 1, and if cond2 is false, low becomes p1 + 1. In both cases, we also modify p1 itself first so that our loop does not conclude with p1 in an invalid location.
Getting the Final Answer
Now that we have p1 and p2, we have a couple final tricks to get the answer. We want to take the smaller value between arr1[p1] and arr2[p2]. But we have to handle the edge case where p1 might be n, AND we want to increment the index for whichever array we take from. Note that p2 cannot be out of bounds at this point!
let mut median = arr2[p2];
if (p1 < n && arr1[p1] < arr2[p2]) {
median = arr1[p1];
p1 += 1;
} else {
p2 += 1;
}
If the total number of elements is odd, we can simply return this number (converting to a float). However, in the even case we need one more number to take an average. So we’ll compare the values at the indices again, now accounting for the fact that either index (but not both) could be out of bounds.
let mut median = arr2[p2];
if (p1 < n && arr1[p1] < arr2[p2]) {
median = arr1[p1];
p1 += 1;
} else {
p2 += 1;
}
if (t % 2 == 0) {
if (p1 >= n) {
median += arr2[p2];
} else if (p2 >= m) {
median += arr1[p1];
} else {
median += cmp::min(arr1[p1], arr2[p2]);
}
let medianF: f64 = median.into();
return medianF / 2.0;
} else {
return median.into();
}
Here’s the complete implementation:
pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
let mut n = nums1.len();
let mut m = nums2.len();
let mut arr1: &Vec<i32> = &nums1;
let mut arr2: &Vec<i32> = &nums2;
if (m < n) {
n = nums2.len();
m = nums1.len();
arr1 = &nums2;
arr2 = &nums1;
}
let t = n + m;
let p: usize = (t - 1) / 2;
let mut low = 0;
let mut hi = n;
let mut p1 = 0;
let mut p2 = 0;
while (low < hi) {
p1 = (low + hi) / 2;
p2 = p - p1;
let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
let cond2 = p1 == n || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
if (cond1 && cond2) {
break;
} else if (!cond1) {
p1 -= 1;
hi = p1;
} else {
p1 += 1;
low = p1;
}
}
p2 = p - p1;
let mut median = arr2[p2];
if (p1 < n && arr1[p1] < arr2[p2]) {
median = arr1[p1];
p1 += 1;
} else {
p2 += 1;
}
if (t % 2 == 0) {
if (p1 >= n) {
median += arr2[p2];
} else if (p2 >= m) {
median += arr1[p1];
} else {
median += cmp::min(arr1[p1], arr2[p2]);
}
let medianF: f64 = median.into();
return medianF / 2.0;
} else {
return median.into();
}
}
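To see it run end to end, here is the same solution as a self-contained program, lightly adapted: we add the `use std::cmp;` import the body relies on, rename `medianF` to the Rust-conventional `median_f`, and drop the parentheses around conditions. The inputs in `main` are the classic small examples:

```rust
use std::cmp;

pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
    // Ensure arr1 refers to the shorter array, so we binary search over it.
    let mut n = nums1.len();
    let mut m = nums2.len();
    let mut arr1: &Vec<i32> = &nums1;
    let mut arr2: &Vec<i32> = &nums2;
    if m < n {
        n = nums2.len();
        m = nums1.len();
        arr1 = &nums2;
        arr2 = &nums1;
    }
    let t = n + m;
    let p: usize = (t - 1) / 2;
    let mut low = 0;
    let mut hi = n;
    let mut p1 = 0;
    let mut p2 = 0;
    // Binary search for the correct p1.
    while low < hi {
        p1 = (low + hi) / 2;
        p2 = p - p1;
        let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
        let cond2 = p1 == n || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
        if cond1 && cond2 {
            break;
        } else if !cond1 {
            p1 -= 1;
            hi = p1;
        } else {
            p1 += 1;
            low = p1;
        }
    }
    p2 = p - p1;
    // Take the smaller of the two candidates, advancing that array's index.
    let mut median = arr2[p2];
    if p1 < n && arr1[p1] < arr2[p2] {
        median = arr1[p1];
        p1 += 1;
    } else {
        p2 += 1;
    }
    if t % 2 == 0 {
        // Even total: add the next smallest value and average.
        if p1 >= n {
            median += arr2[p2];
        } else if p2 >= m {
            median += arr1[p1];
        } else {
            median += cmp::min(arr1[p1], arr2[p2]);
        }
        let median_f: f64 = median.into();
        median_f / 2.0
    } else {
        median.into()
    }
}

fn main() {
    assert_eq!(find_median_sorted_arrays(vec![1, 3], vec![2]), 2.0);
    assert_eq!(find_median_sorted_arrays(vec![1, 2], vec![3, 4]), 2.5);
    println!("ok");
}
```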
Haskell Implementation
Now let’s examine the Haskell implementation. Unlike the LeetCode version, we’ll just assume our inputs are Double already instead of doing a conversion. Once again, we start by defining the terms:
import qualified Data.Vector as V

medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
where
n' = V.length input1
m' = V.length input2
t = n' + m'
p = (t - 1) `quot` 2
(n, m, arr1, arr2) = if V.length input1 <= V.length input2
then (n', m', input1, input2) else (m', n', input2, input1)
...
Now we’ll implement the binary search, this time doing a recursive function. We’ll do this in two parts, starting with a helper function. This helper function will simply tell us if a particular index is correct for p1. The trick though is that we’ll return an Ordering instead of just a Bool:
-- data Ordering = LT | EQ | GT
f :: Int -> Ordering
This lets us signal 3 possibilities. If we return EQ, this means the index is valid. If we return LT, this will mean we want fewer values from arr1. And then GT means we want more values from arr1.
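This three-way signal translates directly to Rust as well, via std::cmp::Ordering. As an aside, here’s a hypothetical Rust mirror of the helper (`classify` is our own name for it, not from the article’s code):

```rust
use std::cmp::Ordering;

// Hypothetical Rust mirror of the Haskell helper: classify a candidate p1.
// Equal: p1 is correct. Less: take fewer elements from arr1 (search left).
// Greater: take more elements from arr1 (search right).
// Assumes arr1 is the shorter array and p1 <= arr1.len().
fn classify(arr1: &[i32], arr2: &[i32], p: usize, p1: usize) -> Ordering {
    let (n, p2) = (arr1.len(), p - p1);
    let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
    let cond2 = p1 == n || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
    if cond1 && cond2 {
        Ordering::Equal
    } else if !cond1 {
        Ordering::Less
    } else {
        Ordering::Greater
    }
}

fn main() {
    // arr1 = [1, 2], arr2 = [3, 4]; t = 4, so p = 1, and p1 = 1 is correct.
    assert_eq!(classify(&[1, 2], &[3, 4], 1, 1), Ordering::Equal);
    // With arr1 = [3, 4], taking one prior element from arr1 is too many.
    assert_eq!(classify(&[3, 4], &[1, 2], 1, 1), Ordering::Less);
    println!("ok");
}
```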
With this framing, it’s easy to see the implementation of this helper. We determine the appropriate p2, evaluate our two conditions, and return the corresponding Ordering:
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
where
...
f :: Int -> Ordering
f pi1 =
let pi2 = p - pi1
cond1 = pi1 == 0 || arr1 V.! (pi1 - 1) <= arr2 V.! pi2
cond2 = pi1 == n || pi2 == 0 || (arr2 V.! (pi2 - 1) <= arr1 V.! pi1)
in if cond1 && cond2 then EQ else if (not cond1) then LT else GT
Now we can apply this helper in a recursive binary search. The search tracks two pieces of state for our interval ((Int, Int)), and it will return the correct value for p1. The implementation handles the base case (return low if low >= hi), determines the midpoint, calls our helper, and then recurses appropriately based on the helper’s result.
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
where
...
f :: Int -> Ordering
f pi1 = ...
search :: (Int, Int) -> Int
search (low, hi) = if low >= hi then low else
let mid = (low + hi) `quot` 2
in case f mid of
EQ -> mid
LT -> search (low, mid - 1)
GT -> search (mid + 1, hi)
p1 = search (0, n)
p2 = p - p1
...
For the final part of the problem, we’ll define a helper. Given p1 and p2, it will emit the “lower” value between the two indices in the array (accounting for edge cases) as well as the two new indices (since one will increment).
This is a matter of lazily defining the “next” value for each array, the “end” condition of each array, and the “result” if that array’s value is chosen:
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
where
...
findNext pi1 pi2 =
let next1 = arr1 V.! pi1
next2 = arr2 V.! pi2
end1 = pi1 >= n
end2 = pi2 >= m
res1 = (next1, pi1 + 1, pi2)
res2 = (next2, pi1, pi2 + 1)
in if end1 then res2
else if end2 then res1
else if next1 <= next2 then res1 else res2
Now we just apply this either once or twice to get our result!
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = result
where
...
tIsEven = even t
(median1, nextP1, nextP2) = findNext p1 p2
(median2, _, _) = findNext nextP1 nextP2
result = if tIsEven
then (median1 + median2) / 2.0
else median1
Here’s the complete implementation:
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = result
where
n' = V.length input1
m' = V.length input2
t = n' + m'
p = (t - 1) `quot` 2
(n, m, arr1, arr2) = if V.length input1 <= V.length input2
then (n', m', input1, input2) else (m', n', input2, input1)
-- Evaluate the index in arr1
-- If it indicates the index can be part of the median, return EQ
-- If it indicates we need to move left in arr1, return LT
-- If it indicates we need to move right in arr1, return GT
-- Precondition: pi1 <= n
f :: Int -> Ordering
f pi1 =
let pi2 = p - pi1
cond1 = pi1 == 0 || arr1 V.! (pi1 - 1) <= arr2 V.! pi2
cond2 = pi1 == n || pi2 == 0 || (arr2 V.! (pi2 - 1) <= arr1 V.! pi1)
in if cond1 && cond2 then EQ else if (not cond1) then LT else GT
search :: (Int, Int) -> Int
search (low, hi) = if low >= hi then low else
let mid = (low + hi) `quot` 2
in case f mid of
EQ -> mid
LT -> search (low, mid - 1)
GT -> search (mid + 1, hi)
findNext pi1 pi2 =
let next1 = arr1 V.! pi1
next2 = arr2 V.! pi2
end1 = pi1 >= n
end2 = pi2 >= m
res1 = (next1, pi1 + 1, pi2)
res2 = (next2, pi1, pi2 + 1)
in if end1 then res2
else if end2 then res1
else if next1 <= next2 then res1 else res2
p1 = search (0, n)
p2 = p - p1
tIsEven = even t
(median1, nextP1, nextP2) = findNext p1 p2
(median2, _, _) = findNext nextP1 nextP2
result = if tIsEven
then (median1 + median2) / 2.0
else median1
Conclusion
If you want to learn more about these kinds of problem solving techniques, you should take our course Solve.hs! In the coming weeks, we’ll see more problems related to data structures and algorithms, which are covered extensively in Modules 2 and 3 of that course!
Buffer & Save with a Challenging Example
Welcome back to our series comparing LeetCode problems in Haskell and Rust. Today we’ll learn a new paradigm that I call “Buffer and Save”. This will also be the hardest problem we’ve done so far! The core loop structure isn’t that hard, but there are a couple layers of tricks to massage our data to get the final answer.
This will be the last problem we do that focuses strictly on string and list manipulation. The next set of problems we do will all rely on more advanced data structures or algorithmic ideas.
For more complete practice on problem solving in Haskell, check out Solve.hs, our newest course. This course will teach you everything you need to know about problem solving, data structures, and algorithms in Haskell. You’ll get loads of practice building structures and algorithms from scratch, which is very important for understanding and remembering how they work.
The Problem
Today’s problem is Text Justification. The idea here is that we are taking a list of words and a “maximum width” and printing out the words grouped into equal-width lines that are evenly spaced. Here’s an example input and output:
Example Input (list of 9 strings):
[“Study”, “Haskell”, “with”, “us”, “every”, “Monday”, “Morning”, “for”, “fun”]
Max Width: 16
Output (list of 4 strings):
“Study    Haskell”
“with   us  every”
“Monday   Morning”
“for fun         ”
There are a few notable rules, constraints, and edge cases. Here’s a list to summarize them:
- There is at least one word
- No word is larger than the max width
- All output strings must have max width as their length (including spaces)
- The first word of every line is set to the left
- The last line always has 1 space between words, and then enough spaces after the last word to reach the max width.
- All other lines with multiple words will align the final word all the way to the right
- The spaces in non-final lines are distributed as evenly as possible, with any extra spaces going in the blanks closest to the left.
The final point is potentially the trickiest to understand. Consider the second line above, with us every. The max width is 16, and we have 3 words with a total of 11 characters. This leaves us 5 spaces. Having 3 words means 2 blanks, so the “left” blank gets 3 spaces and the “right” blank gets 2 spaces.
If you had a line with 5 words, a max width of 30, and 16 characters, you would place 4 spaces in the left two blanks, and 3 spaces in the right two blanks. The relative length of the words does not matter.
Words in Line: [“A”, “good”, “day”, “to”, “endure”]
Output Line:
“A    good    day   to   endure”
The Algorithm
As mentioned above, our main algorithmic idea could be called “buffer and save”. We’ve been defining all of our loops based on the state we must maintain between iterations of the loop. The buffer and save approach highlights two pieces of state for us:
- The strings we’ve accumulated for our answer so far (the “result”)
- A buffer of the strings in the “current” line we’re building.
So we’ll loop through the input words one at a time. We’ll consider if the next word can be added to the “current” line. If it would cause our current line to exceed the maximum width, we’ll “save” our current line and write it out to the “result” list, adding the required spaces.
To help our calculations, we’ll also include two other pieces of state in our loop:
- The number of characters in our “current” line
- The number of words in our “current” line
Finally, there’s the question of how to construct each output line. Combining the math with list-mechanics is a little tricky. But the central idea consists of 4 simple steps:
- Find the number of spaces (subtract number of characters from max width)
- Divide the number of spaces by the number of “blanks” (number of words - 1)
- The quotient is the “base” number of spaces per blank
- The remainder is the number of blanks (starting from the left) that get an extra space
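The four steps translate quite directly into code. Here’s a small Rust sketch of the non-final-line math; `justify_line` is a hypothetical helper, separate from both full solutions:

```rust
// Hypothetical helper: justify one (non-final) line of words to `width`.
fn justify_line(words: &[&str], width: usize) -> String {
    let chars: usize = words.iter().map(|w| w.len()).sum();
    if words.len() == 1 {
        // Single word: left-align and pad with spaces.
        return format!("{}{}", words[0], " ".repeat(width - chars));
    }
    let spaces = width - chars; // step 1: total spaces to place
    let blanks = words.len() - 1; // number of gaps between words
    let base = spaces / blanks; // steps 2-3: "base" spaces per gap
    let extra = spaces % blanks; // step 4: leftmost gaps get one more
    let mut line = String::from(words[0]);
    for (i, word) in words[1..].iter().enumerate() {
        let gap = base + if i < extra { 1 } else { 0 };
        line.push_str(&" ".repeat(gap));
        line.push_str(word);
    }
    line
}

fn main() {
    assert_eq!(
        justify_line(&["A", "good", "day", "to", "endure"], 30),
        "A    good    day   to   endure"
    );
    assert_eq!(justify_line(&["with", "us", "every"], 16), "with   us  every");
    println!("ok");
}
```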
The exact implementation of this idea differs between Haskell and Rust. Again this rests a lot on the “reverse” differences between Rust vectors and Haskell lists.
The final line has a slightly different (but easier) process. And we should note that the final line will still be in our buffer when we exit the loop! So we shouldn’t forget to add it to the result.
Haskell Solution
We know enough now to jump into our Haskell solution. Our solution should be organized around a loop. Since we go through the input word-by-word, this should follow a fold pattern. So here’s our outline:
justifyText :: [String] -> Int -> [String]
justifyText inputWords maxWidth = ...
where
-- f = ‘final’
(fLine, fWordsInLine, fCharsInLine, result) = foldl loop ([], 0, 0, []) inputWords
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord = ...
Let’s focus in on the choice we have to make in the loop. We need to determine if this new word fits in our current line. So we’ll get its length and add it to the number of characters in the line AND consider the number of words in the line. We count the words too since each word we already have requires at least one space!
-- (maxWidth is still in scope here)
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord =
let newWordLen = length newWord
in if newWordLen + charsInLine + wordsInLine > maxWidth
then ...
else ...
How do we fill in these choices? If we don’t overflow the line, we just append the new word, bump the count of the words, and add the new word’s length to the character count.
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord =
let newWordLen = length newWord
in if newWordLen + charsInLine + wordsInLine > maxWidth
then ...
else (newWord : currentLine, wordsInLine + 1, charsInLine + newWordLen, currResult)
The overflow case isn’t hard, but it does require us to have a function that can convert our current line into the final string. This function will also take the number of words and characters in this line. Assuming this function exists, we just make this new line, append it to result, and then reset our other stateful values so that they only reflect the “new word” as part of our current line.
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord =
let newWordLen = length newWord
resultLine = makeLine currentLine wordsInLine charsInLine
in if newWordLen + charsInLine + wordsInLine > maxWidth
then ([newWord], 1, newWordLen, resultLine : currResult)
else (newWord : currentLine, wordsInLine + 1, charsInLine + newWordLen, currResult)
makeLine :: [String] -> Int -> Int -> String
makeLine = ...
Before we think about the makeLine implementation though, we have almost enough to fill in the rest of the “top” of our function definition. We just need another function for making the “final” line, since this is different from other lines. Then when we get our “final” state values, we’ll plug them into this function to get our final line, append it to the result, and reverse the whole list.
justifyText :: [String] -> Int -> [String]
justifyText inputWords maxWidth =
reverse (makeLineFinal fLine fWordsInLine fCharsInLine : result)
where
(fLine, fWordsInLine, fCharsInLine, result) = foldl loop ([], 0, 0, []) inputWords
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord =
let newWordLen = length newWord
resultLine = makeLine currentLine wordsInLine charsInLine
in if newWordLen + charsInLine + wordsInLine > maxWidth
then ([newWord], 1, newWordLen, resultLine : currResult)
else (newWord : currentLine, wordsInLine + 1, charsInLine + newWordLen, currResult)
makeLine :: [String] -> Int -> Int -> String
makeLine = ...
makeLineFinal :: [String] -> Int -> Int -> String
makeLineFinal = ...
Now let’s discuss forming these lines, starting with the general case. We can start with a couple edge cases. This should never be called with an empty list. And with a singleton, we just left-align the word and add the right number of spaces:
makeLine :: [String] -> Int -> Int -> String
makeLine [] _ _ = error "Cannot makeLine with empty string!"
makeLine [onlyWord] _ charsInLine =
let extraSpaces = replicate (maxWidth - charsInLine) ' '
in onlyWord <> extraSpaces
makeLine (first : rest) wordsInLine charsInLine = ...
Now we’ll calculate the quotient and remainder to get the spacing sizes, as mentioned in our algorithm section. But how do we combine them? There are multiple ways, but the idea I thought of was to zip the tail of the list with the number of spaces to append after each word. Then we can fold these pairs into a resulting string using a function like this:
-- (String, Int) is the next string and the number of spaces after it
combine :: String -> (String, Int) -> String
combine suffix (nextWord, numSpaces) =
nextWord <> replicate numSpaces ' ' <> suffix
Remember while doing this that we’ve accumulated the words for each line in reverse order. So we want to append each one in succession, together with the number of spaces that come after it.
To use this function, we can “fold” over the “tail” of our current line, while using the first word in our list as the base of the fold! Don’t forget the quotRem math going on in here!
makeLine :: [String] -> Int -> Int -> String
makeLine [] _ _ = error "Cannot makeLine with empty string!"
makeLine [onlyWord] _ charsInLine =
let extraSpaces = replicate (maxWidth - charsInLine) ' '
in onlyWord <> extraSpaces
makeLine (first : rest) wordsInLine charsInLine =
let (baseNumSpaces, numWithExtraSpace) = quotRem (maxWidth - charsInLine) (wordsInLine - 1)
baseSpaces = replicate (wordsInLine - 1 - numWithExtraSpace) baseNumSpaces
extraSpaces = replicate numWithExtraSpace (baseNumSpaces + 1)
wordsWithSpaces = zip rest (baseSpaces <> extraSpaces)
in foldl combine first wordsWithSpaces
combine :: String -> (String, Int) -> String
combine suffix (nextWord, numSpaces) =
nextWord <> replicate numSpaces ' ' <> suffix
To make the final line, we can also leverage our combine function! It’s just a matter of combining each word in our input with the appropriate number of spaces. In this case, almost every word gets 1 space except for the last one (which comes first in our list). This just gets however many trailing spaces we need!
makeLineFinal :: [String] -> Int -> Int -> String
makeLineFinal [] _ _ = error "Cannot makeLineFinal with empty string!"
makeLineFinal strs wordsInLine charsInLine =
let trailingSpaces = maxWidth - charsInLine - (wordsInLine - 1)
in foldl combine "" (zip strs (trailingSpaces : repeat 1))
Putting all these pieces together, we have our complete solution!
justifyText :: [String] -> Int -> [String]
justifyText inputWords maxWidth =
reverse (makeLineFinal fLine fWordsInLine fCharsInLine : result)
where
(fLine, fWordsInLine, fCharsInLine, result) = foldl loop ([], 0, 0, []) inputWords
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord =
let newWordLen = length newWord
resultLine = makeLine currentLine wordsInLine charsInLine
in if newWordLen + charsInLine + wordsInLine > maxWidth
then ([newWord], 1, newWordLen, resultLine : currResult)
else (newWord : currentLine, wordsInLine + 1, charsInLine + newWordLen, currResult)
makeLine :: [String] -> Int -> Int -> String
makeLine [] _ _ = error "Cannot makeLine with empty string!"
makeLine [onlyWord] _ charsInLine =
let extraSpaces = replicate (maxWidth - charsInLine) ' '
in onlyWord <> extraSpaces
makeLine (first : rest) wordsInLine charsInLine =
let (baseNumSpaces, numWithExtraSpace) = quotRem (maxWidth - charsInLine) (wordsInLine - 1)
baseSpaces = replicate (wordsInLine - 1 - numWithExtraSpace) baseNumSpaces
extraSpaces = replicate numWithExtraSpace (baseNumSpaces + 1)
wordsWithSpaces = zip rest (baseSpaces <> extraSpaces)
in foldl combine first wordsWithSpaces
makeLineFinal :: [String] -> Int -> Int -> String
makeLineFinal [] _ _ = error "Cannot makeLineFinal with empty string!"
makeLineFinal strs wordsInLine charsInLine =
let trailingSpaces = maxWidth - charsInLine - (wordsInLine - 1)
in foldl combine "" (zip strs (trailingSpaces : repeat 1))
combine :: String -> (String, Int) -> String
combine suffix (nextWord, numSpaces) = nextWord <> replicate numSpaces ' ' <> suffix
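To check our work end to end, here’s the same code assembled into a single runnable module, with a main that prints the result for LeetCode’s first example (maxWidth = 16):

```haskell
-- The complete solution from above, assembled into one module with
-- a main that runs LeetCode's first example at maxWidth = 16.
justifyText :: [String] -> Int -> [String]
justifyText inputWords maxWidth =
    reverse (makeLineFinal fLine fWordsInLine fCharsInLine : result)
  where
    (fLine, fWordsInLine, fCharsInLine, result) =
      foldl loop ([], 0, 0, []) inputWords

    loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
    loop (currentLine, wordsInLine, charsInLine, currResult) newWord =
      let newWordLen = length newWord
          resultLine = makeLine currentLine wordsInLine charsInLine
      in if newWordLen + charsInLine + wordsInLine > maxWidth
           then ([newWord], 1, newWordLen, resultLine : currResult)
           else (newWord : currentLine, wordsInLine + 1, charsInLine + newWordLen, currResult)

    makeLine :: [String] -> Int -> Int -> String
    makeLine [] _ _ = error "Cannot makeLine with empty line!"
    makeLine [onlyWord] _ charsInLine =
      onlyWord <> replicate (maxWidth - charsInLine) ' '
    makeLine (first : rest) wordsInLine charsInLine =
      let (baseNumSpaces, numWithExtraSpace) =
            quotRem (maxWidth - charsInLine) (wordsInLine - 1)
          baseSpaces = replicate (wordsInLine - 1 - numWithExtraSpace) baseNumSpaces
          extraSpaces = replicate numWithExtraSpace (baseNumSpaces + 1)
      in foldl combine first (zip rest (baseSpaces <> extraSpaces))

    makeLineFinal :: [String] -> Int -> Int -> String
    makeLineFinal [] _ _ = error "Cannot makeLineFinal with empty line!"
    makeLineFinal strs wordsInLine charsInLine =
      let trailingSpaces = maxWidth - charsInLine - (wordsInLine - 1)
      in foldl combine "" (zip strs (trailingSpaces : repeat 1))

    combine :: String -> (String, Int) -> String
    combine suffix (nextWord, numSpaces) =
      nextWord <> replicate numSpaces ' ' <> suffix

main :: IO ()
main = mapM_ putStrLn $
  justifyText ["This", "is", "an", "example", "of", "text", "justification."] 16
```

This prints the three fully-justified 16-character lines the problem expects.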
Rust Solution
Now let’s put together our Rust solution. Since we have a reasonable outline from writing this in Haskell, let’s start with the simpler elements, make_line and make_line_final. We’ll use library functions as much as possible for the string manipulation. For example, we can start make_line_final by using join on our input vector of strings.
pub fn make_line_final(
currentLine: &Vec<&str>,
max_width: usize,
charsInLine: usize) -> String {
let mut result = currentLine.join(" ");
...
}
Now we just need to calculate the number of trailing spaces by subtracting the length of the joined string from max_width. We append these by using repeat on a single blank space for the correct number of times.
pub fn make_line_final(
currentLine: &Vec<&str>,
max_width: usize,
charsInLine: usize) -> String {
let mut result = currentLine.join(" ");
let trailingSpaces = max_width - result.len();
result.push_str(&" ".repeat(trailingSpaces));
return result;
}
For those unfamiliar with Rust, the type of our input vector might seem odd. When we have &Vec<&str>, this means a reference to a vector of string slices. String slices are portions of a String that we hold a reference to, but they aren’t copied. However, when we join them, we make a new String result.
Also note that we aren’t passing wordsInLine as a separate parameter. We can get this value using .len() in constant time in Rust. In Haskell, length is O(n) so we don’t want to always do that.
Now for the general make_line function, we have the same type signature, but we start with our base case, where we only have one string in our current line. Again, we use repeat with the number of spaces.
pub fn make_line(
currentLine: &Vec<&str>,
max_width: usize,
charsInLine: usize) -> String {
let mut result = String::new();
let n = currentLine.len();
if n == 1 {
result.push_str(currentLine[0]);
result.push_str(&" ".repeat(max_width - charsInLine));
return result;
}
...
}
Now we do the “math” portion of this. Rust doesn’t have a single quotRem function in its base library, so we calculate these values separately.
pub fn make_line(
currentLine: &Vec<&str>,
max_width: usize,
charsInLine: usize) -> String {
let mut result = String::new();
let n = currentLine.len();
if n == 1 {
result.push_str(currentLine[0]);
result.push_str(&" ".repeat(max_width - charsInLine));
return result;
}
let numSpaces = max_width - charsInLine;
let baseNumSpaces = numSpaces / (n - 1);
let numWithExtraSpace = numSpaces % (n - 1);
let mut i = 0;
while i < n {
...
}
return result;
}
The while loop we’ll write here is instructive. We use an index instead of a for-each pattern because the index tells us how many spaces to use. If our index is smaller than numWithExtraSpace, we add 1 to the base number of spaces. Otherwise we use the base number, up until the index n - 1. The last word gets no spaces after it, so we’re done at that point!
pub fn make_line(
currentLine: &Vec<&str>,
max_width: usize,
charsInLine: usize) -> String {
let mut result = String::new();
let n = currentLine.len();
if n == 1 {
result.push_str(currentLine[0]);
result.push_str(&" ".repeat(max_width - charsInLine));
return result;
}
let numSpaces = max_width - charsInLine;
let baseNumSpaces = numSpaces / (n - 1);
let numWithExtraSpace = numSpaces % (n - 1);
let mut i = 0;
while i < n {
result.push_str(currentLine[i]);
if i < numWithExtraSpace {
result.push_str(&" ".repeat(baseNumSpaces + 1));
} else if i < n - 1 {
result.push_str(&" ".repeat(baseNumSpaces));
}
i += 1;
}
return result;
}
Now we frame our solution. Let’s start by setting up our state variables (again, omitting wordsInLine, since currentLine.len() gives us that). We’ll also redefine max_width as a usize value for ease of comparison later.
pub fn full_justify(words: Vec<String>, max_width: i32) -> Vec<String> {
let mut currentLine = Vec::new();
let mut charsInLine = 0;
let mut result = Vec::new();
let mw = max_width as usize;
...
}
Now we’d like to frame our solution as a “for each” loop. However, this doesn’t work, for Rust-related reasons we’ll describe after the solution! Instead, we’ll use an index loop.
pub fn full_justify(words: Vec<String>, max_width: i32) -> Vec<String> {
let mut currentLine = Vec::new();
let mut charsInLine = 0;
let mut result = Vec::new();
let mw = max_width as usize;
let n = words.len();
for i in 0..n {
...
}
}
We’ll get the word by index on each iteration, and use its length to see if we’ll exceed the max width. If not, we can safely push it onto currentLine and increase the character count:
pub fn full_justify(words: Vec<String>, max_width: i32) -> Vec<String> {
let mut currentLine = Vec::new();
let mut charsInLine = 0;
let mut result = Vec::new();
let mw = max_width as usize;
let n = words.len();
for i in 0..n {
let word = &words[i];
if word.len() + charsInLine + currentLine.len() > mw {
...
} else {
currentLine.push(&words[i]);
charsInLine += word.len();
}
}
}
Now when we do exceed the max width, we have to push our current line onto result (calling make_line). We clear the current line, push our new word, and use its length for charsInLine.
pub fn full_justify(words: Vec<String>, max_width: i32) -> Vec<String> {
let mut currentLine = Vec::new();
let mut charsInLine = 0;
let mut result = Vec::new();
let mw = max_width as usize;
let n = words.len();
for i in 0..n {
let word = &words[i];
if word.len() + charsInLine + currentLine.len() > mw {
result.push(make_line(&currentLine, mw, charsInLine));
currentLine.clear();
currentLine.push(&words[i]);
charsInLine = word.len();
} else {
currentLine.push(&words[i]);
charsInLine += word.len();
}
}
...
}
After our loop, we’ll just call make_line_final on whatever is left in our currentLine! Here’s our complete full_justify function that calls make_line and make_line_final as we wrote above:
pub fn full_justify(words: Vec<String>, max_width: i32) -> Vec<String> {
let mut currentLine = Vec::new();
let mut charsInLine = 0;
let mut result = Vec::new();
let mw = max_width as usize;
let n = words.len();
for i in 0..n {
let word = &words[i];
if word.len() + charsInLine + currentLine.len() > mw {
result.push(make_line(&currentLine, mw, charsInLine));
currentLine.clear();
currentLine.push(&words[i]);
charsInLine = word.len();
} else {
currentLine.push(&words[i]);
charsInLine += word.len();
}
}
result.push(make_line_final(&currentLine, mw, charsInLine));
return result;
}
Why an Index Loop?
Inside our Rust loop, we have an odd pattern in getting the “word” for this iteration. We first assign word = &words[i], and then later on, when we push that word, we reference words[i] again, using currentLine.push(&words[i]).
Why do this? Why not currentLine.push(word)? And then, why can’t we just do for word in words as our loop?
If we write our loop as for word in words, each word is an owned String that is moved out of the vector and dropped at the end of its iteration. It is “scoped” to the loop, so we cannot store a reference to it in currentLine, which “outlives” the loop! We have to reference currentLine at the end when we make our final line. Indexing with &words[i] instead borrows from words itself, which lives for the whole function. (Iterating over &words by reference would also work; the index loop just makes the borrow explicit.)
The only way around this with for word in words would be to copy each word instead of using a string reference &str, but this is unnecessarily expensive.
These are the sorts of odd “lifetime” quirks you have to learn to deal with in Rust. Haskell is easier in that it spares us from thinking about this. But Rust gains a significant performance boost with these sorts of ideas.
Conclusion
This was definitely the most involved problem we’ve dealt with so far. We learned a new paradigm (buffer and save), and got some experience dealing with some of the odd quirks and edge cases of string manipulation, especially in Rust. It was a fairly tricky problem, as far as list manipulation goes. For an easier example of a buffer and save problem, try solving Merge Intervals.
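For a taste of the same “buffer and save” fold shape in that simpler setting, here’s a hedged sketch of Merge Intervals. (The function name mergeIntervals and the tuple representation are our own choices for illustration, not LeetCode’s signature.)

```haskell
import Data.List (sortOn)

-- A sketch of the "buffer and save" pattern on Merge Intervals:
-- buffer the current interval, and save it to the (reversed)
-- result whenever the next interval doesn't overlap it.
mergeIntervals :: [(Int, Int)] -> [(Int, Int)]
mergeIntervals [] = []
mergeIntervals intervals = reverse (buffer : result)
  where
    (first : rest) = sortOn fst intervals
    (buffer, result) = foldl loop (first, []) rest
    loop ((bStart, bEnd), acc) (start, end) =
      if start <= bEnd
        then ((bStart, max bEnd end), acc)        -- overlap: grow the buffer
        else ((start, end), (bStart, bEnd) : acc) -- gap: save and reset

main :: IO ()
main = print (mergeIntervals [(1,3), (2,6), (8,10), (15,18)])
-- [(1,6),(8,10),(15,18)]
```

The shape is identical to justifyText’s fold: accumulate into a buffer, flush the buffer onto a reversed result on a boundary condition, and finish by flushing and reversing.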
If you want to level up your Haskell problem solving skills, you need to take our course Solve.hs. This course will teach you everything you need to know about problem solving, data structures, and algorithms in Haskell. After this course, you’ll be in great shape to deal with these sorts of LeetCode style problems as they come up in your projects.
The Sliding Window in Haskell & Rust
In last week’s problem, we covered a two-pointer algorithm, and compared Rust and Haskell solutions as we have been for this whole series. Today, we’ll study a related concept, the sliding window problem. Whereas the general two-pointer problem can often be tackled by a single loop, we’ll have to use nested loops in this problem. This problem will also mark our first use of the Set data structure in this series.
If you want a deeper look at problem solving techniques in Haskell, you should enroll in our Solve.hs course! You’ll learn everything you need for general problem solving knowledge in Haskell, including data structures, algorithms, and parsing!
The Problem
Today’s LeetCode problem is Longest Substring without Repeating Characters. It’s a lengthy problem name, but the name basically tells you everything you need to know! We want to find a substring of our input that does not repeat any characters within the substring, and then get the longest such substring.
For example, abaca would give us an answer of 3, since we have the substring bac that consists of 3 unique characters. However, abaaca only gives us 2. There is no run of 3 characters where the three characters are all unique.
The Algorithm
The approach we’ll use, as mentioned above, is called a sliding window algorithm. In some ways, this is similar to the two-pointer approach last week. We’ll have, in a sense, two different pointers within our input. One dictates the “left end” of a window and one dictates the “right end” of a window. Unlike last week’s problem though, both pointers will move in the same direction, rather than converging from opposite directions.
The goal of a sliding window problem is “find a continuous subsequence of an input that matches the criteria”. And for many problems like ours, you want to find the longest such subsequence. The main process for a sliding window problem is this:
- Grow the window by increasing the “right end” until (or while) the predicate is satisfied
- Once you cannot grow the window any more, shrink the window by increasing the “left end” until we’re in a position to grow the window again.
- Continue until one or both pointers go off the end of the input list.
So for our problem today, we want to “grow” our sliding window as long as we can get more unique characters. Once we hit a character we’ve already seen in our current window, we’ll need to shrink the window until that duplicate character is removed from the set.
As we’re doing this, we’ll need to keep track of the largest substring size we’ve seen so far.
Here are the steps we would take with the input abaca. At each step, we process a new input character.
1. Index 0 (‘a’) - window is “a” which is all unique.
2. Index 1 (‘b’) - window is “ab” which is all unique
3. Index 2 (‘a’) - window is “aba”, which is not all unique
3b. Shrink window, removing first ‘a’, so it is now “ba”
4. Index 3 (‘c’) - window is “bac”, which is all unique
5. Index 4 (‘a’) - window is “baca”, which is not unique
5b. Shrink window, remove ‘b’ and ‘a’, leaving “ca”
The largest unique window we saw was bac, so the final answer is 3.
Haskell Solution
For a change of pace, let’s discuss the Haskell approach first. Our algorithm is laid out in such a way that we can process one character at a time. Each character either grows the window, or forces it to shrink to accommodate the character. This means we can use a fold!
Let’s think about what state we need to track within this fold. Naturally, we want to track the current “set” of characters in our window. Each time we see the next character, we have to quickly determine if it’s already in the window. We’ll also want to track the largest set size we’ve seen so far, since by the end of the string our window might no longer reflect the largest subsequence.
With a general sliding window approach, you would also need to track both the start and the end index of your current window. In this problem though, we can get away with just tracking the start index. We can always derive the end index by taking the start index and adding the size of the set. And since we’re iterating through the characters anyway, we don’t need the end index to get the “next” character.
This means our fold-loop function will have this type signature:
-- State: (start index, set of letters, largest seen)
loop :: (Int, S.Set Char, Int) -> Char -> (Int, S.Set Char, Int)
Now, using our idea of “beginning from the end”, we can already write the invocation of this loop:
largestUniqueSubsequence :: String -> Int
largestUniqueSubsequence input = best
where
(_, _, best) = foldl loop (0, S.empty, 0) input
loop :: (Int, S.Set Char, Int) -> Char -> (Int, S.Set Char, Int)
...
Using 0 for the start index right away is a little hand-wavy, since we haven’t actually added the first character to our set yet! But if we see a single character, we’ll always add it, and as we’ll see, the “adding” branch of our loop never increases this number.
With that in mind, let’s write this branch of our loop handler! If we have not seen the next character in the string, we keep the same start index (left side of the window isn’t moving), we add the character to our set, and we take the new size of the set as the “best” value if it’s greater than the original. We get the new size by adding 1 to the original set size.
largestUniqueSubsequence :: String -> Int
largestUniqueSubsequence input = best
where
(_, _, best) = foldl loop (0, S.empty, 0) input
loop :: (Int, S.Set Char, Int) -> Char -> (Int, S.Set Char, Int)
loop (startIndex, charSet, bestSoFar) c = if S.notMember c charSet
then (startIndex, S.insert c charSet, max bestSoFar (S.size charSet + 1))
else ...
Now we reach the tricky case! If we’ve already seen the next character, we need to remove characters from our set until we reach the instance of this character in the set. Since we might need to remove multiple characters, “shrinking” is an iterative process with a variable number of steps. This means it would be a while-loop in most languages, which means we need another recursive function!
The goal of this function is to change two of our stateful values (the start index and the character set) until we can once again have a unique character set with the new input character. So each iteration it takes the existing values for these, and will ultimately return updated values. Here’s its type signature:
shrink :: (Int, S.Set Char) -> Char -> (Int, S.Set Char)
Before we implement this, we can invoke it in our primary loop! When we’ve seen the new character in our set, we shrink the input to match this character, and then return these new stateful values along with our previous best (shrinking never increases the size).
largestUniqueSubsequence :: String -> Int
largestUniqueSubsequence input = best
where
(_, _, best) = foldl loop (0, S.empty, 0) input
loop :: (Int, S.Set Char, Int) -> Char -> (Int, S.Set Char, Int)
loop (startIndex, charSet, bestSoFar) c = if S.notMember c charSet
then (startIndex, S.insert c charSet, max bestSoFar (S.size charSet + 1))
else
let (newStart, newSet) = shrink (startIndex, charSet) c
in (newStart, newSet, bestSoFar)
shrink :: (Int, S.Set Char) -> Char -> (Int, S.Set Char)
shrink = undefined
Now we implement “shrink” by considering the base case and recursive case. In the base case, the character at this index matches the new character we’re trying to remove. So we can return the same set of characters, but increase the index.
In the recursive case, we still increase the index, but now we remove the character at the start index from the set. (Note that we need a vector version of our input, inputV, for efficient indexing here.)
largestUniqueSubsequence :: String -> Int
largestUniqueSubsequence input = best
where
(_, _, best) = foldl loop (0, S.empty, 0) input
inputV = V.fromList input
loop :: (Int, S.Set Char, Int) -> Char -> (Int, S.Set Char, Int)
loop (startIndex, charSet, bestSoFar) c = if S.notMember c charSet
then (startIndex, S.insert c charSet, max bestSoFar (S.size charSet + 1))
else
let (newStart, newSet) = shrink (startIndex, charSet) c
in (newStart, newSet, bestSoFar)
shrink :: (Int, S.Set Char) -> Char -> (Int, S.Set Char)
shrink (startIndex, charSet) c =
let nextC = inputV V.! startIndex
-- Base Case: nextC is equal to c
in if nextC == c then (startIndex + 1, charSet)
-- Recursive Case: remove the character at startIndex
else shrink (startIndex + 1, S.delete nextC charSet) c
Now we have a complete Haskell solution!
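If you’d like to run this without the vector package, here’s a self-contained variant that swaps the V.! lookup for a plain list index (!!). This is fine for a demo, though it gives up the O(1) indexing we wanted:

```haskell
import qualified Data.Set as S

-- A runnable variant of the solution above, using a plain list
-- index (!!) in place of the Data.Vector lookup so that it needs
-- no extra packages. Each lookup is O(n) rather than O(1).
largestUniqueSubsequence :: String -> Int
largestUniqueSubsequence input = best
  where
    (_, _, best) = foldl loop (0, S.empty, 0) input

    loop :: (Int, S.Set Char, Int) -> Char -> (Int, S.Set Char, Int)
    loop (startIndex, charSet, bestSoFar) c = if S.notMember c charSet
      then (startIndex, S.insert c charSet, max bestSoFar (S.size charSet + 1))
      else
        let (newStart, newSet) = shrink (startIndex, charSet) c
        in (newStart, newSet, bestSoFar)

    shrink :: (Int, S.Set Char) -> Char -> (Int, S.Set Char)
    shrink (startIndex, charSet) c =
      let nextC = input !! startIndex
      in if nextC == c
           then (startIndex + 1, charSet)
           else shrink (startIndex + 1, S.delete nextC charSet) c

main :: IO ()
main = do
  print (largestUniqueSubsequence "abaca")  -- 3
  print (largestUniqueSubsequence "abaaca") -- 2
```

Both examples match the traces we walked through in the algorithm section.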
Rust Solution
Now in our Rust solution, we’ll follow the same pattern we’ve been doing for these problems. We’ll set up our loop variables, write the loop, and handle the different cases in the loop. Because we had the nested recursive “shrink” function in Haskell, this will translate to a “while” loop in Rust, nested within our for-loop.
Here’s how we set up our loop variables:
pub fn length_of_longest_substring(s: String) -> i32 {
let mut best = 0;
let mut startIndex = 0;
let inputV: Vec<char> = s.chars().collect();
let mut charSet = HashSet::new();
for c in s.chars() {
...
}
}
Within the loop, we have the “easy” case, where the next character is not already in our set. We just insert it into our set, and we update best if we have a new maximum.
pub fn length_of_longest_substring(s: String) -> i32 {
let mut best = 0;
let mut startIndex = 0;
let inputV: Vec<char> = s.chars().collect();
let mut charSet = HashSet::new();
for c in s.chars() {
if charSet.contains(&c) {
...
} else {
charSet.insert(c);
best = std::cmp::max(best, charSet.len());
}
}
return best as i32;
}
The Rust-specific oddity is that when we call contains on the HashSet, we must use &c, passing a reference to the character. In C++ we could just copy the character, or it could be handled by the function using const&. But Rust handles these things a little differently.
Now we get to the “tricky” case within our loop. How do we “shrink” our set to consume a new character?
In our case, we’ll actually just use the loop functionality of Rust, which works like while (true), requiring a manual break inside the loop. Our idea is that we’ll inspect the character at the “start” index of our window. If this character is the same as the new character, we will advance the start index (indicating we are dropping the old version), but then we’ll break. Otherwise, we’ll still increase the index, but we’ll remove the other character from the set as well.
Here’s what this loop looks like in relative isolation:
if charSet.contains(&c) {
loop {
// Look at “first” character of window
let nextC = inputV[startIndex];
if nextC == c {
// If it’s the new character, we advance past it and break
startIndex += 1;
break;
} else {
// Otherwise, advance AND drop it from the set
startIndex += 1;
charSet.remove(&nextC);
}
}
} else {
...
}
The inner condition (nextC == c) feels a little flimsy to use with a while (true) loop. But it’s perfectly sound because of the invariant that if charSet contains c, we’ll necessarily find nextC == c before startIndex gets too large. We could also write it as a normal while loop, but loop is an interesting Rust-specific idea to bring in here.
Here’s our complete Rust solution!
use std::collections::HashSet;

pub fn length_of_longest_substring(s: String) -> i32 {
let mut best = 0;
let mut startIndex = 0;
let inputV: Vec<char> = s.chars().collect();
let mut charSet = HashSet::new();
for c in s.chars() {
if charSet.contains(&c) {
loop {
let nextC = inputV[startIndex];
if nextC == c {
startIndex += 1;
break;
} else {
startIndex += 1;
charSet.remove(&nextC);
}
}
} else {
charSet.insert(c);
best = std::cmp::max(best, charSet.len());
}
}
return best as i32;
}
Conclusion
With today’s problem, we’ve covered another important problem-solving concept: the sliding window. We saw how this approach could work even with a fold in Haskell, considering one character at a time. We also saw how nested loops compare across Haskell and Rust.
For more problem solving tips and tricks, take a look at Solve.hs, our complete course on problem solving, data structures, and algorithms in Haskell. You’ll get tons of practice on problems like these so you can significantly level up your skills!
Two Pointer Algorithms
We’re now on to part 5 of our series comparing Haskell and Rust solutions for LeetCode problems. You can also look at the previous parts (Part 1, Part 2, Part 3, Part 4) to get some more context on what we’ve learned so far comparing these two languages.
For a full look at problem solving in Haskell, check out Solve.hs, our latest course! You’ll get full breakdowns on the processes for solving problems in Haskell, from basic list and loop problems to advanced algorithms!
The Problem
Today we’ll be looking at a problem called Trapping Rain Water. In this problem, we’re given a vector of heights, which form a sort of 1-dimensional topology. Our job is to figure out how many units of water could be collected within the topology.
As a very simple example, the input [1,0,2] could collect 1 unit of water. Here’s a visualization of that system, where x shows the topology and o shows water we collect:
x
xox
We can never collect any water over the left or right “edges” of the array, since it would flow off. The middle index of our array though is lower than its neighbors. So we take the lower of these neighboring values, and we see that we can collect 1 unit of water in this system.
For a bigger example that collects water, we might have the input [4, 2, 1, 1, 3, 5]. Here’s what that looks like:
x
x o o o o x
x o o o x x
x x o o x x
x x x x x x
The total water here is 9.
A flat system like [2,2,2], or a system that looks like a peak [1,2,3,2,1] cannot collect any water, so we should return 0 in these cases.
The Algorithm
There are a couple ways to solve this. One approach would be a two-pass solution, similar to what we used in Product of Array Except Self. We loop from the left side, tracking the maximum water we can store in each unit based on its left neighbors. Then we loop again from the right side and compare the maximum we can store based on the right neighbors to the prior value from the left. This solution is O(n) time, but O(n) space as well.
A better solution for this problem is a two-pointer approach that uses only O(1) additional space. In this kind of solution, we look at the left and right of the input simultaneously. Each step of the way, we make a decision to either increase the “left pointer” or decrease the “right pointer” until they meet in the middle. Each time we move, we get more information about our solution.
In this particular problem, we’ll track the maximum value we’ve seen from the left side and the maximum value we’ve seen from the right side. As we traverse each index, we update both sides for the current left and right indices if we have a new maximum.
The crucial step is to see that if the current “left max” is smaller than the current “right max”, we know how much water can be stored at the left index. This is just the left max minus the height at the left index. Then we can increment the left index.
If the opposite is true, we calculate how much water can be stored at the right index, and decrease the right index.
So we keep a running tally of these sums, and we end our loop when they meet in the middle.
Rust Solution
We can describe our algorithm as a simple while loop. This loop goes until the left index exceeds the right index. The loop needs to track 5 values:
- Left Index
- Right Index
- Left Max
- Right Max
- Total sum so far
So let’s write the setup portion of the loop:
pub fn trap(height: Vec<i32>) -> i32 {
let mut leftMax = -1;
let mut rightMax = -1;
let mut leftI = 0;
let mut rightI = height.len() - 1;
let mut total = 0;
while leftI <= rightI {
...
}
}
A subtle thing: the constraints on the LeetCode problem say the length is at least 1. But to handle length-0 cases, we would need a special case. Rust uses unsigned integers for vector length, so taking height.len() - 1 on a length-0 vector would underflow (panicking in debug builds, or wrapping to the maximum usize value in release builds), and this would mess up our loop and indexing.
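The underflow hazard is easy to demonstrate with any unsigned type. As a toy illustration (not part of the solution), Haskell’s Word behaves like a release-mode Rust usize here:

```haskell
main :: IO ()
main = do
  -- Subtracting 1 from an unsigned zero wraps around to the maximum
  -- value, which is what height.len() - 1 would do on an empty
  -- vector in a release-mode Rust build.
  print ((0 :: Word) - 1 == maxBound)  -- True
```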
Within the while loop, we run the algorithm.
- Adjust leftMax and rightMax if necessary.
- If leftMax is not larger, increment leftI and add to total from the left.
- If rightMax is smaller, decrement rightI and add to total from the right.
And at the end, we return our total!
pub fn trap(height: Vec<i32>) -> i32 {
let n = height.len();
if n <= 1 {
return 0;
}
let mut leftMax = -1;
let mut rightMax = -1;
let mut leftI = 0;
let mut rightI = n - 1;
let mut total = 0;
while leftI <= rightI {
// Step 1
leftMax = std::cmp::max(leftMax, height[leftI]);
rightMax = std::cmp::max(rightMax, height[rightI]);
if leftMax <= rightMax {
// Step 2
total += leftMax - height[leftI];
leftI += 1;
} else {
// Step 3
total += rightMax - height[rightI];
rightI -= 1;
}
}
return total;
}
Haskell Solution
Now that we’ve seen our Rust solution with a single loop, let’s remember our process for translating this idea to Haskell. With a two-pointer loop, the way in which we traverse the elements of the input is unpredictable, thus we need a raw recursive function, rather than a fold or a map.
Since we’re tracking 5 integer values, we’ll want to write a loop function that looks like this:
-- (leftIndex, rightIndex, leftMax, rightMax, sum)
loop :: (Int, Int, Int, Int, Int) -> Int
Knowing this, we can already “start from the end” and figure out how to invoke our loop from the start of our function:
trapWater :: V.Vector Int -> Int
trapWater input = loop (0, n - 1, -1, -1, 0)
where
n = V.length input
loop :: (Int, Int, Int, Int, Int) -> Int
loop = undefined
In writing our recursive loop, we’ll start with the base case. Once leftI is the bigger index, we return the total.
trapWater :: V.Vector Int -> Int
trapWater input = loop (0, n - 1, -1, -1, 0)
where
n = V.length input
loop :: (Int, Int, Int, Int, Int) -> Int
loop (leftI, rightI, leftMax, rightMax, total) = if leftI > rightI then total
else ...
Within the else case, we just follow our algorithm, with the same 3 steps we saw with Rust.
trapWater :: V.Vector Int -> Int
trapWater input = loop (0, n - 1, -1, -1, 0)
where
n = V.length input
-- (leftIndex, rightIndex, leftMax, rightMax, sum)
loop :: (Int, Int, Int, Int, Int) -> Int
loop (leftI, rightI, leftMax, rightMax, total) = if leftI > rightI then total
else
-- Step 1
let leftMax' = max leftMax (input V.! leftI)
rightMax' = max rightMax (input V.! rightI)
in if leftMax' <= rightMax'
-- Step 2
then loop (leftI + 1, rightI, leftMax', rightMax', total + leftMax' - input V.! leftI)
-- Step 3
else loop (leftI, rightI - 1, leftMax', rightMax', total + rightMax' - input V.! rightI)
And we have our Haskell solution!
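As with the sliding window problem, here’s a dependency-free variant you can run directly, swapping the V.! lookups for plain list indexing (fine for small demos, though each lookup is O(n) rather than O(1)):

```haskell
-- A runnable variant of trapWater using plain list indexing (!!)
-- instead of Data.Vector, so it needs no extra packages.
trapWater :: [Int] -> Int
trapWater input = loop (0, n - 1, -1, -1, 0)
  where
    n = length input
    -- (leftIndex, rightIndex, leftMax, rightMax, sum)
    loop :: (Int, Int, Int, Int, Int) -> Int
    loop (leftI, rightI, leftMax, rightMax, total) =
      if leftI > rightI then total
      else
        let leftMax' = max leftMax (input !! leftI)
            rightMax' = max rightMax (input !! rightI)
        in if leftMax' <= rightMax'
             then loop (leftI + 1, rightI, leftMax', rightMax', total + leftMax' - input !! leftI)
             else loop (leftI, rightI - 1, leftMax', rightMax', total + rightMax' - input !! rightI)

main :: IO ()
main = do
  print (trapWater [1, 0, 2])          -- 1
  print (trapWater [4, 2, 1, 1, 3, 5]) -- 9
  print (trapWater [2, 2, 2])          -- 0
```

These match the examples from the problem description, including the flat system that collects no water.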
Conclusion
If you’ve been following this whole series so far, hopefully you’re starting to get a feel for comparing basic algorithms in Haskell and Rust (standing as a proxy for most loop-based languages). In general, we can write loops as recursive functions in Haskell, capturing the “state” of the list as the input parameter for that function.
In particular cases where each iteration deals with exactly one element of an input list, we can employ folds as a tool to simplify our functions. But the two-pointer algorithm we explored today falls into the general recursive category.
To learn the details of understanding these problem solving techniques, take a look at our course, Solve.hs! You’ll learn everything from basic loop and list techniques, to advanced data structures and algorithms!