Dynamic Programming Primer
We’re about to start our final stretch of Haskell/Rust LeetCode comparisons (for now). In this group, we’ll do a quick study of some dynamic programming problems, which are a common source of headaches in programming interviews. We’ll do a couple of single-dimension problems, and then show DP in multiple dimensions. Haskell has a couple of interesting quirks to work out with dynamic programming, so we’ll try to understand them by comparison to Rust.
Dynamic programming is one of a few different algorithms you’ll learn about in Module 3 of Solve.hs, our Haskell problem solving course. Check it out today!
The Problems
Today’s problem is called House Robber. Normally we wouldn’t want to encourage crime, but when people have such a convoluted security setup as this problem suggests, perhaps they can’t complain.
The idea is that we receive a list of integers, representing the value that we can gain from “robbing” each house on a street. The security system is set up so that the police will be alerted if and only if two adjacent houses are robbed. So we could rob every other house, and no police will come.
We are trying to determine the maximum value we can get from robbing these houses without setting off the alarm (by robbing adjacent houses).
Dynamic Programming Introduction
As mentioned in the introduction, we’ll solve this problem using dynamic programming. This term can mean a couple different things, but by and large the idea is that we use answers on smaller portions of the input (going down as far as base cases) to build up to the final answer on the full input.
This can be done from the top down, generally by means of recursion. If we do this, we’ll often want to “cache” answers (Memoization) to certain parts of the problem so we don’t do redundant calculations.
We can also build answers from the bottom up (tabulation). This often takes the form of creating an array and storing answers to smaller queries in this array. Then we loop from the start of this array to the end, which should give us our answer. Our solutions in this series will largely rest on this tabulation idea. However, as we’ll see, we don’t always need an array of all prior answers to do this!
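As a quick illustration of these two styles (separate from today’s problem), here is a minimal Haskell sketch for Fibonacci numbers; the function names here are just for demonstration:
-- Top-down: recursion, with a lazy list standing in for a memoization cache.
fibMemo :: Int -> Integer
fibMemo n = table !! n
  where
    table = map fib [0 ..]
    fib 0 = 0
    fib 1 = 1
    fib i = table !! (i - 1) + table !! (i - 2)

-- Bottom-up (tabulation): start from the base cases and build forward,
-- keeping only the last two values, much like we'll do below.
fibTab :: Int -> Integer
fibTab n = go n 0 1
  where
    go 0 a _ = a
    go k a b = go (k - 1) b (a + b)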
The key in dynamic programming is to define, whether for an array index or a recursive call, exactly what a “partial” solution means. This will help us use partial solutions to build our complete solution.
The Algorithm
Now let’s figure out how we’ll use dynamic programming for our house robbing problem. The broad idea is that we could define two arrays, the “robbed” array and the “unrobbed” array. Each of these should be equal in size to the number of houses on the street. Let’s carefully define what each array means.
Index i of the “robbed” array should reflect the maximum value we can get from the houses [0..i] such that we have “robbed” house i. Then the “unrobbed” array, at index i, contains the maximum total value we can get from the houses [0..i] such that we have not robbed house i.
When it comes to populating these arrays we need to think first about the base cases. Then we need to consider how to build a new case from existing cases we have. With a recursive solution we have the same pattern: base case and recursive case.
The first two indices for each can be trivially calculated; they are our base cases:
robbed[0] = input[0] // Rob house 0
robbed[1] = input[1] // Rob house 1
unrobbed[0] = 0 // Can’t rob any houses
unrobbed[1] = input[0] // Rob house 0
Now we need to build a generic case i, assuming that we have already calculated all the values from 0 to i - 1. To calculate robbed[i], we assume we are robbing house i, so we add input[i]. If we are robbing house i we must not have robbed house i - 1, so we add input[i] to unrobbed[i - 1].

To calculate unrobbed[i], we have the option of whether or not we robbed house i - 1. It may be advantageous to skip two houses in a row! Consider an example like [100, 1, 1, 100]. So we take the maximum of unrobbed[i - 1] and robbed[i - 1].

This gives us our general case, and so at the end we simply select the maximum of robbed[n - 1] and unrobbed[n - 1].
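To make the recurrence concrete, here is how the two arrays evolve for the example input [100, 1, 1, 100]:
i : input robbed unrobbed
0 : 100   100    0
1 : 1     1      100
2 : 1     101    100
3 : 100   200    101
The answer is max(200, 101) = 200, which corresponds to robbing houses 0 and 3.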
We’ve been speaking in terms of arrays, but we can observe that we only need the i - 1 value from each array to construct the i values. This means we don’t actually have to store a complete array, which would take O(n) memory. Instead we can store the last “robbed” number and the last “unrobbed” number. This makes our solution O(1) memory.
Haskell Solution
Now let’s write some code, starting with Haskell! LeetCode guarantees that our input is non-empty, but we still need to handle the size-1 case specially:
robHouse :: V.Vector Int -> Int
robHouse nums = if n == 1 then nums V.! 0
else ...
where
n = V.length nums
...
Now let’s write a recursive loop function that will take our prior two values (robbed and unrobbed) as well as the index. These are the “stateful” values of our loop. We’ll use these to either return the final value, or make a recursive call with new “robbed” and “unrobbed” values.
robHouse :: V.Vector Int -> Int
robHouse nums = if n == 1 then nums V.! 0
else ...
where
n = V.length nums
loop :: (Int, Int) -> Int -> Int
loop (lastRobbed, lastUnrobbed) i = ...
For the “final” case, we see if we have reached the end of our array (i = n), in which case we return the max of the two values:
robHouse :: V.Vector Int -> Int
robHouse nums = if n == 1 then nums V.! 0
else ...
where
n = V.length nums
loop :: (Int, Int) -> Int -> Int
loop (lastRobbed, lastUnrobbed) i = if i == n then max lastRobbed lastUnrobbed
else ...
Now we fill in our recursive case, using the logic discussed in our algorithm:
robHouse :: V.Vector Int -> Int
robHouse nums = if n == 1 then nums V.! 0
else ...
where
n = V.length nums
loop :: (Int, Int) -> Int -> Int
loop (lastRobbed, lastUnrobbed) i = if i == n then max lastRobbed lastUnrobbed
else
let newRobbed = nums V.! i + lastUnrobbed
newUnrobbed = max lastRobbed lastUnrobbed
in loop (newRobbed, newUnrobbed) (i + 1)
Finally, we make the initial call to loop to get our answer! This completes our Haskell solution:
robHouse :: V.Vector Int -> Int
robHouse nums = if n == 1 then nums V.! 0
else loop (nums V.! 1, nums V.! 0) 2
where
n = V.length nums
loop :: (Int, Int) -> Int -> Int
loop (lastRobbed, lastUnrobbed) i = if i == n then max lastRobbed lastUnrobbed
else
let newRobbed = nums V.! i + lastUnrobbed
newUnrobbed = max lastRobbed lastUnrobbed
in loop (newRobbed, newUnrobbed) (i + 1)
Even when tabulating from the ground up in Haskell, we can still use recursion!
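As a quick sanity check (assuming an import of qualified Data.Vector as V), we can try a couple of inputs in GHCi:
>>> robHouse (V.fromList [100, 1, 1, 100])
200
>>> robHouse (V.fromList [2, 7, 9, 3, 1])
12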
Rust Solution
Our Rust solution is similar, just using a loop instead of a recursive function. We start by handling our edge case and coming up with the initial values for “last robbed” and “last unrobbed”.
pub fn rob(nums: Vec<i32>) -> i32 {
let n = nums.len();
if n == 1 {
return nums[0];
}
let mut lastRobbed = nums[1];
let mut lastUnrobbed = nums[0];
...
}
Now we just apply our algorithmic logic in a loop from 2 to n, resetting lastRobbed and lastUnrobbed each time.
pub fn rob(nums: Vec<i32>) -> i32 {
let n = nums.len();
if n == 1 {
return nums[0];
}
let mut lastRobbed = nums[1];
let mut lastUnrobbed = nums[0];
for i in 2..n {
let newRobbed = nums[i] + lastUnrobbed;
let newUnrobbed = std::cmp::max(lastUnrobbed, lastRobbed);
lastRobbed = newRobbed;
lastUnrobbed = newUnrobbed;
}
return std::cmp::max(lastRobbed, lastUnrobbed);
}
And now we’re done with Rust!
Conclusion
Next week we’ll do a problem that actually requires us to store a full array of prior solutions. To learn the different stages in building up an understanding of dynamic programming, you should take our problem solving course, Solve.hs. Module 3 focuses on algorithms, including dynamic programming!
Apply the Trie: Word Search
Today will be a nice culmination of some of the work we’ve been doing with data structures and algorithms. In the past few weeks we’ve covered graph algorithms, particularly Depth First Search. And last week, we implemented the Trie data structure from scratch. Today we’ll solve a “Hard” problem (according to LeetCode) that pulls these pieces together!
For a comprehensive study of data structures and algorithms in Haskell, you should take a look at our course, Solve.hs. You’ll spend a full module on data structures in Haskell, and then another module learning about algorithms, especially graph algorithms!
The Problem
Today’s problem is Word Search II. In the first version of this problem, LeetCode asks you to determine if we can find a single word in a grid. In this second version, we receive an entire list of words, and we have to return the subset of those words that can be found in the grid.
Now the “search” mechanism on this grid is not simply a straight line search. We’ll use a limited Boggle Search. From each letter in a word, we can make any movement to the next letter, as long as it is horizontal or vertical (Boggle also allows diagonal, but we won’t).
So here’s an example grid:
CAND
XBIY
TENQ
Words like “CAN” and “TEN” are obviously allowed. But we can also use “CAB”, even though it doesn’t form a single straight line. Even better, we can use “CABINET”, snaking through all 3 rows. However, a word like “DIN” is disallowed since it would require a diagonal move. Also, we cannot “re-use” letters in the same word. So “TENET” would not be allowed, as it requires backtracking over the E and T.
We need to find the most efficient way to determine which of our input words can be found in the grid.
The Algorithm
This problem combines two elements we’ve recently worked with. First, we will use DFS ideas to actually search through the grid from a particular starting location. Second, we will use a Trie to store all of our input words. This will help us determine if we can stop searching. Once we find a string that is not a prefix in our Trie, we can discontinue the search branch.
Here’s a run-down of the solution:
- Make a Trie from the Input Words
- Search from each starting location in the grid, trying to add good words to a growing set of results.
- At each location, add the character to our string. See if the resulting string is still a valid prefix of our Trie.
- If it is, add the word to our results if the Trie indicates it is a valid word. Then search the neighboring locations, while keeping track of locations we’ve visited.
- Once we no longer have a valid Trie, we can stop the line of searching
This run-down obscures quite a few details, but given our last few weeks of work, those details shouldn’t be too difficult.
Updating Trie Functions
For both our solutions, we’ll assume we’re using the same Trie structure we built last week. We’ll need the general structure, as well as the insert function.

However, we won’t actually need the search or startsWith functions. Each time we call these, we traverse the full length of the string we’re querying. With the way our algorithm works here, that would get quite inefficient (quadratic time overall).

Instead we’re going to rely on directly accessing sub-Tries, so that as we build our search word longer, we hold a Trie that already reflects the prefix we’ve matched in the main Trie. This keeps each subsequent step fast.

To make this more convenient, we’ll just provide a function to get a Maybe Trie from the “sub-Tries” of the node we’re working with, based on the character. We’ll call this “popping” a Trie.
Here’s what it looks like in Rust:
impl Trie {
...
fn pop(&self, c: char) -> Option<&Trie> {
return self.nodes.get(&c);
}
}
And in Haskell:
popTrie :: Char -> Trie -> Maybe Trie
popTrie c (Trie _ subs) = M.lookup c subs
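For intuition, repeatedly popping characters recovers the old prefix-checking behavior. Here is a hypothetical helper (not part of the solution) showing that a chain of popTrie calls just walks down one path of the Trie:
walkPrefix :: String -> Trie -> Maybe Trie
walkPrefix [] t = Just t
walkPrefix (c : cs) t = popTrie c t >>= walkPrefix cs
The point of the search below is that we never need to walk a full string like this; we pop exactly one character per grid step.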
For completeness, here’s the entire Rust version of the Trie that we’ll need for this problem:
use std::collections::HashMap;
use std::str::Chars;
struct Trie {
endsWord: bool,
nodes: HashMap<char, Trie>
}
impl Trie {
fn new() -> Self {
Trie {
endsWord: false,
nodes: HashMap::new()
}
}
fn insert(&mut self, word: String) {
self.insertIt(word.chars());
}
fn insertIt(&mut self, mut iter: Chars) {
if let Some(c) = iter.next() {
if !self.nodes.contains_key(&c) {
self.nodes.insert(c, Trie::new());
}
if let Some(subTrie) = self.nodes.get_mut(&c) {
subTrie.insertIt(iter);
}
} else {
self.endsWord = true;
}
}
fn pop(&self, c: char) -> Option<&Trie> {
return self.nodes.get(&c);
}
}
And here’s the full Haskell version:
data Trie = Trie Bool (M.Map Char Trie)
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) =
let sub = fromMaybe (Trie False M.empty) (M.lookup c subs)
newSub = insertTrie cs sub
in (Trie ends (M.insert c newSub subs))
popTrie :: Char -> Trie -> Maybe Trie
popTrie c (Trie _ subs) = M.lookup c subs
Rust Solution
Now let’s move on to our solution, starting with Rust. As always with a graph problem, we’ll benefit from having a neighbors function. This will be very similar to functions we’ve written in the past few weeks, so we won’t dwell on it. In this case though, we’ll incorporate the visited set directly into this function, and exclude neighbors we’ve already seen:
pub fn neighbors(
nr: usize,
nc: usize,
visitedLocs: &HashSet<(usize, usize)>,
loc: (usize, usize)) -> Vec<(usize, usize)> {
let r = loc.0;
let c = loc.1;
let mut results = Vec::new();
if (r > 0 && !visitedLocs.contains(&(r - 1, c))) {
results.push((r - 1, c));
}
if (c > 0 && !visitedLocs.contains(&(r, c - 1))) {
results.push((r, c - 1));
}
if (r + 1 < nr && !visitedLocs.contains(&(r + 1, c))) {
results.push((r + 1, c));
}
if (c + 1 < nc && !visitedLocs.contains(&(r, c + 1))) {
results.push((r, c + 1));
}
return results;
}
Now, thinking back to the Islands example, we want to write a search function similar to the visit function we had before. The job of the visit function was to populate the visited set with all reachable tiles from the start. Our search function will populate a set of “results” with every word reachable from a certain location.

It will require some immutable inputs, such as the board dimensions and the board itself, as well as several mutable, stateful items like the current Trie, the String we are building, and the current visited set. Here’s the signature we will use:
pub fn search(
nr: usize,
nc: usize,
board: &Vec<Vec<char>>,
trie: &Trie,
loc: (usize, usize),
visitedLocs: &mut HashSet<(usize, usize)>,
currentStr: &mut String,
seenWords: &mut HashSet<String>) {
...
}
This function has two tasks. First, assess the current location to see if the word we completed by arriving here should be added, or if it’s at least a prefix of a remaining word. If either is true, our second job is to find this location’s neighbors and recursively call them, continuing to grow our string and search for longer words.
Let’s write the code for the first part. Naturally, we need the character at this grid location. Then we need to query our Trie to “pop” the sub-Trie associated with this character. If this sub-Trie doesn’t exist, we immediately return. Otherwise, we consider this location “visited” (add it to the set) and we push the new character onto the string. If our new sub-Trie “ends a word”, then we add this word to our results set!
pub fn search(
nr: usize,
nc: usize,
board: &Vec<Vec<char>>,
trie: &Trie,
loc: (usize, usize),
visitedLocs: &mut HashSet<(usize, usize)>,
currentStr: &mut String,
seenWords: &mut HashSet<String>) {
let c = board[loc.0][loc.1];
if let Some(subTrie) = trie.pop(c) {
currentStr.push(c);
visitedLocs.insert(loc);
if subTrie.endsWord {
seenWords.insert(currentStr.clone());
}
...
}
}
Now we get our neighbors and recursively search them, passing updated mutable values. But there’s one extra thing to include! After we are done searching our neighbors, we should “undo” our mutable changes to the visited set and the string.
We haven’t had to make this kind of “backtracking” change before. But we don’t want to permanently keep this location in this visited set, nor keep the string modified. When we return to our caller, we want these mutable values to be the same as how we got them. Otherwise, subsequent calls may be disturbed, and we’ll get incorrect answers!
pub fn search(
nr: usize,
nc: usize,
board: &Vec<Vec<char>>,
trie: &Trie,
loc: (usize, usize),
visitedLocs: &mut HashSet<(usize, usize)>,
currentStr: &mut String,
seenWords: &mut HashSet<String>) {
let c = board[loc.0][loc.1];
if let Some(subTrie) = trie.pop(c) {
currentStr.push(c);
visitedLocs.insert(loc);
if subTrie.endsWord {
seenWords.insert(currentStr.clone());
}
let ns = neighbors(nr, nc, visitedLocs, loc);
for n in ns {
search(nr, nc, board, subTrie, n, visitedLocs, currentStr, seenWords);
}
// Backtrack! Remove this location and pop this character
visitedLocs.remove(&loc);
currentStr.pop();
}
}
This completes the search function. Now we just have to call it! We start our primary function by initializing our key values, especially our Trie. We need to insert all the starting words into a Trie that we create:
pub fn find_words(board: Vec<Vec<char>>, words: Vec<String>) -> Vec<String> {
let mut trie = Trie::new();
for word in &words {
trie.insert(word.to_string());
}
let mut results = HashSet::new();
let nr = board.len();
let nc = board[0].len();
...
}
And now we just loop through each location in the grid and search it as a starting location! For an extra optimization, we can stop our search early if we have found all of our words.
pub fn find_words(board: Vec<Vec<char>>, words: Vec<String>) -> Vec<String> {
let mut trie = Trie::new();
for word in &words {
trie.insert(word.to_string());
}
let mut results = HashSet::new();
let nr = board.len();
let nc = board[0].len();
for i in 0..nr {
for j in 0..nc {
if results.len() < words.len() {
let mut visited = HashSet::new();
let mut curr = String::new();
search(nr, nc, &board, &trie, (i, j), &mut visited, &mut curr, &mut results);
}
}
}
return results.into_iter().collect();
}
And we’re done! Here is our complete Rust solution!
pub fn neighbors(
nr: usize,
nc: usize,
visitedLocs: &HashSet<(usize, usize)>,
loc: (usize, usize)) -> Vec<(usize, usize)> {
let r = loc.0;
let c = loc.1;
let mut results = Vec::new();
if (r > 0 && !visitedLocs.contains(&(r - 1, c))) {
results.push((r - 1, c));
}
if (c > 0 && !visitedLocs.contains(&(r, c - 1))) {
results.push((r, c - 1));
}
if (r + 1 < nr && !visitedLocs.contains(&(r + 1, c))) {
results.push((r + 1, c));
}
if (c + 1 < nc && !visitedLocs.contains(&(r, c + 1))) {
results.push((r, c + 1));
}
return results;
}
pub fn search(
nr: usize,
nc: usize,
board: &Vec<Vec<char>>,
trie: &Trie,
loc: (usize, usize),
visitedLocs: &mut HashSet<(usize, usize)>,
currentStr: &mut String,
seenWords: &mut HashSet<String>) {
let c = board[loc.0][loc.1];
if let Some(subTrie) = trie.pop(c) {
currentStr.push(c);
visitedLocs.insert(loc);
if subTrie.endsWord {
seenWords.insert(currentStr.clone());
}
let ns = neighbors(nr, nc, visitedLocs, loc);
for n in ns {
search(nr, nc, board, subTrie, n, visitedLocs, currentStr, seenWords);
}
visitedLocs.remove(&loc);
currentStr.pop();
}
}
pub fn find_words(board: Vec<Vec<char>>, words: Vec<String>) -> Vec<String> {
let mut trie = Trie::new();
for word in &words {
trie.insert(word.to_string());
}
let mut results = HashSet::new();
let nr = board.len();
let nc = board[0].len();
for i in 0..nr {
for j in 0..nc {
if results.len() < words.len() {
let mut visited = HashSet::new();
let mut curr = String::new();
search(nr, nc, &board, &trie, (i, j), &mut visited, &mut curr, &mut results);
}
}
}
return results.into_iter().collect();
}
Haskell Solution
Our Haskell solution starts with some of the same beats. We’ll create our initial Trie through insertion and define a familiar neighbors function:
findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = ...
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds board
trie = foldr insertTrie (Trie False M.empty) allWords
neighbors :: HS.HashSet (Int, Int) -> (Int, Int) -> [(Int, Int)]
neighbors visited (r, c) =
let up = if r > minRow && not (HS.member (r - 1, c) visited) then Just (r - 1, c) else Nothing
left = if c > minCol && not (HS.member (r, c - 1) visited) then Just (r, c - 1) else Nothing
down = if r < maxRow && not (HS.member (r + 1, c) visited) then Just (r + 1, c) else Nothing
right = if c < maxCol && not (HS.member (r, c + 1) visited) then Just (r, c + 1) else Nothing
in catMaybes [up, left, down, right]
...
Now let’s think about our search function. This function’s job is to update a set, given a particular location. We’re going to loop over this function with many locations. So we want the end of its signature to look like:
(Int, Int) -> HS.HashSet String -> HS.HashSet String
This will allow us to use it with foldr. But we still want to think about the “mutable” elements going into this function: the Trie, the visited set, and the accumulated string. These also change from call to call, but they should come earlier in the signature, since we can fix them for each pass of the fold. So here’s what our full type signature looks like:
search ::
Trie ->
HS.HashSet (Int, Int) ->
String ->
(Int, Int) ->
HS.HashSet String ->
HS.HashSet String
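The reason for this argument order deserves a tiny standalone example (the names here are hypothetical, not part of the solution): by partially applying the “fixed” arguments first, we’re left with exactly the two-argument shape that foldr expects.
-- foldr wants a function of type (element -> accumulator -> accumulator).
step :: Int -> String -> Int -> Int
step bonus name acc = acc + length name + bonus

total :: Int
total = foldr (step 10) 0 ["ab", "cde"]  -- (length "ab" + 10) + (length "cde" + 10) = 25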
The first order of business in this function is to “pop” the Trie based on the character at this location and see if the sub-Trie exists. If not, we simply return our original set:
findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = ...
where
...
search ::
Trie ->
HS.HashSet (Int, Int) ->
String ->
(Int, Int) ->
HS.HashSet String ->
HS.HashSet String
search trie' visited currentStr loc seenWords = case popTrie (board A.! loc) trie' of
Nothing -> seenWords
Just sub@(Trie ends _) -> ...
Now we update our current string and visited set, while adding the word to our results if the sub-Trie indicates we are at the end of a word:
findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = ...
where
...
search ::
Trie ->
HS.HashSet (Int, Int) ->
String ->
(Int, Int) ->
HS.HashSet String ->
HS.HashSet String
search trie' visited currentStr loc seenWords = case popTrie (board A.! loc) trie' of
Nothing -> seenWords
Just sub@(Trie ends _) ->
let currentStr' = board A.! loc : currentStr
visited' = HS.insert loc visited
seenWords' = if ends then HS.insert (reverse currentStr') seenWords else seenWords
...
And now we get our neighbors and loop through them with foldr. Observe how we define the function f that fixes the first three parameters with our new mutable values so we can cleanly call foldr.
findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = ...
where
...
search ::
Trie ->
HS.HashSet (Int, Int) ->
String ->
(Int, Int) ->
HS.HashSet String ->
HS.HashSet String
search trie' visited currentStr loc seenWords = case popTrie (board A.! loc) trie' of
Nothing -> seenWords
Just sub@(Trie ends _) ->
let currentStr' = board A.! loc : currentStr
visited' = HS.insert loc visited
seenWords' = if ends then HS.insert (reverse currentStr') seenWords else seenWords
ns = neighbors visited loc
f = search sub visited' currentStr'
in foldr f seenWords' ns
We’re almost done now! Having written the “inner” loop, we just have to write the “outer” loop that will loop through every location as a starting point.
findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = HS.toList result
where
...
trie = foldr insertTrie (Trie False M.empty) allWords
result = foldr (search trie HS.empty "") HS.empty (A.indices board)
Here is our complete Haskell solution!
findWords :: A.Array (Int, Int) Char -> [String] -> [String]
findWords board allWords = HS.toList result
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds board
trie = foldr insertTrie (Trie False M.empty) allWords
neighbors :: HS.HashSet (Int, Int) -> (Int, Int) -> [(Int, Int)]
neighbors visited (r, c) =
let up = if r > minRow && not (HS.member (r - 1, c) visited) then Just (r - 1, c) else Nothing
left = if c > minCol && not (HS.member (r, c - 1) visited) then Just (r, c - 1) else Nothing
down = if r < maxRow && not (HS.member (r + 1, c) visited) then Just (r + 1, c) else Nothing
right = if c < maxCol && not (HS.member (r, c + 1) visited) then Just (r, c + 1) else Nothing
in catMaybes [up, left, down, right]
search ::
Trie ->
HS.HashSet (Int, Int) ->
String ->
(Int, Int) ->
HS.HashSet String ->
HS.HashSet String
search trie' visited currentStr loc seenWords = case popTrie (board A.! loc) trie' of
Nothing -> seenWords
Just sub@(Trie ends _) ->
let currentStr' = board A.! loc : currentStr
visited' = HS.insert loc visited
seenWords' = if ends then HS.insert (reverse currentStr') seenWords else seenWords
ns = neighbors visited loc
f = search sub visited' currentStr'
in foldr f seenWords' ns
result = foldr (search trie HS.empty "") HS.empty (A.indices board)
Conclusion
This problem brought together a lot of interesting solution components. We applied our Trie implementation from last week, and used several recurring ideas from graph search problems. Next week we’re going to switch gears a bit and start discussing dynamic programming.
To learn more about all of these problem concepts, you need to take a look at Solve.hs. It gives a fairly comprehensive look at problem solving concepts in Haskell. If you want to understand how to shape your functions to work with folds like we did in this article, you’ll learn about that in Module 1. If you want to implement and apply data structures like graphs and tries, Module 2 will teach you. And if you want practice writing and using key graph algorithms in Haskell, Module 3 will give you the experience you need!
Writing Our Own Structure: Tries in Haskell & Rust
In the last few weeks we’ve studied a few different graph problems. Graphs are interesting because they are a derived structure that we can represent in different ways to solve different problems. Today, we’ll solve a LeetCode problem that actually focuses on writing a data structure ourselves to satisfy certain requirements! Next week, we’ll use this structure to solve a problem.
If you want to improve your Haskell data structure skills, both with built-in types and in making your own types, you should check out Solve.hs, our problem solving course. Module 2 is heavily focused on data structures, so it will clear up a lot of blind spots you might have working with these in Haskell!
The Problem
Unlike our previous problems, we’re not trying to solve some peculiar question formulation with inputs and outputs. Our task today is to implement the basic functions for a Trie data structure. A Trie (pronounced “try”) is also known as a Prefix tree. They are most often used in the context of strings (though other stream-like types are also possible). We’ll make one that is effectively a container of strings that efficiently supports 3 operations:
- Insert - Add a new word into our set
- Search - Determine if we have previously inserted the given word into our tree
- Starts With - Determine if we have inserted any word that has the given input as a prefix
The first two operations are typical of tree sets and hash sets, but the third operation is distinctive for a Trie.
We’ll implement these three functions, as well as provide a means of constructing an empty Trie. We’ll work with the constraint that all our input strings consist only of lowercase English letters.
The Algorithm
If at first you pronounced “Trie” like “Tree”, you’re not really wrong. Our core implementation strategy will be to create a recursive tree structure. It’s easiest if we start by visualizing a trie. Here is a tree structure corresponding to a Trie containing the words “at”, “ate”, “an”, “bank” and “band”.
          _
         / \
        a   b
       / \   \
      t*  n*  a
      |       |
      e*      n
             / \
            k*  d*
The top node is a blank space _ representing the root node. All other nodes in our tree correspond to letters. When we trace a path from the root to any node, we get a valid prefix of a word in our Trie. A star (*) indicates that a node corresponds to a complete word. Note that interior nodes can be complete, as is the case with at.
This suggests a structure for each node in the Trie. A node should store a boolean value telling us if the node completes a word. Other than that, all it needs is a map keying from characters to other nodes further down the tree.
We can use this structure to write all our function implementations in relatively simple recursive terms.
Haskell
Recursive data structures and functions are very natural in Haskell, so we’ll start with that implementation. We’ll make our data type and provide it with the two fields: the boolean indicating the end of a word, and the map of characters to additional Trie nodes.
data Trie = Trie Bool (M.Map Char Trie)
Note that even though a node is visually represented by a particular character, we don’t actually need to store the character on the node. The fact that we arrive at a particular node by keying on a character from its parent is enough.
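To make this concrete, here is roughly what a Trie containing just “at” and “an” would look like with this representation, written out by hand (assuming Data.Map is imported qualified as M, as in the code below):
-- The root is not itself a word; neither is 'a'; 't' and 'n' both end words.
exampleTrie :: Trie
exampleTrie = Trie False (M.fromList
  [ ('a', Trie False (M.fromList
      [ ('t', Trie True M.empty)
      , ('n', Trie True M.empty)
      ]))
  ])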
Now let’s write our implementations, starting with insert. In Haskell, we don’t write functions “on” the type because we can’t mutate values in place. We write functions that take one instance of the type and return another. So our insertTrie function has this signature:
insertTrie :: String -> Trie -> Trie
We want to build a recursive structure, and we have an input (String) that breaks down recursively. This means we have two cases to deal with in this function. Either the string is empty, or it is not:
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie ends subs) = ...
insertTrie (c : cs) (Trie ends subs) = ...
If the string is empty, we don’t need to do much. We return a node with the same sub-tries, but its Bool field is True now! It’s good to remark on certain edge cases. For example, this tells us that we can insert the “empty” string into our Trie. We would just mark the “root” node as True!
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) = ...
In the recursive case, we’ll be making a recursive call on a particular “sub” Trie. We want to “lookup” if a Trie for the character c already exists. If not, we’ll make a default one (with False and an empty sub-tries map).
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) =
let sub = fromMaybe (Trie False M.empty) (M.lookup c subs)
...
Now we recursively insert the rest of the input into this sub-Trie:
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) =
let sub = fromMaybe (Trie False M.empty) (M.lookup c subs)
newSub = insertTrie cs sub
...
Finally, we map this newSub into the original Trie, using c as the key for it:
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) =
let sub = fromMaybe (Trie False M.empty) (M.lookup c subs)
newSub = insertTrie cs sub
in (Trie ends (M.insert c newSub subs))
The search and startsWith functions follow a similar pattern, pattern matching on the input string. With search, an empty string means we look at the Bool field of our Trie. If it’s True, then the word we are searching for was inserted into our Trie:
searchTrie :: String -> Trie -> Bool
searchTrie [] (Trie ends _) = ends
searchTrie (c : cs) (Trie _ subs) = ...
Otherwise, for a non-empty string, we check for the sub-Trie using the character c. If it doesn’t exist, then the word we’re looking for isn’t in our Trie. If it does, we recursively search for the “rest” of the string in that Trie:
searchTrie :: String -> Trie -> Bool
searchTrie [] (Trie ends _) = ends
searchTrie (c : cs) (Trie _ subs) = case M.lookup c subs of
Nothing -> False
Just sub -> searchTrie cs sub
Finally, startsWith is almost identical to search. The only difference is that if we reach the end of the input word, we always return True, as it only needs to be a prefix:
startsWithTrie :: String -> Trie -> Bool
startsWithTrie [] _ = True
startsWithTrie (c : cs) (Trie _ subs) = case M.lookup c subs of
Nothing -> False
Just sub -> startsWithTrie cs sub
There is one interesting case here. We always consider the empty string to be a valid prefix, even if we haven’t inserted anything into our Trie. Perhaps this doesn’t make sense to you, but LeetCode accepts this logic with our Rust solution. You could work around it, but doing so results in less clean code.
Here’s the full Haskell solution:
data Trie = Trie Bool (M.Map Char Trie)
insertTrie :: String -> Trie -> Trie
insertTrie [] (Trie _ subs) = Trie True subs
insertTrie (c : cs) (Trie ends subs) =
let sub = fromMaybe (Trie False M.empty) (M.lookup c subs)
newSub = insertTrie cs sub
in (Trie ends (M.insert c newSub subs))
searchTrie :: String -> Trie -> Bool
searchTrie [] (Trie ends _) = ends
searchTrie (c : cs) (Trie _ subs) = case M.lookup c subs of
Nothing -> False
Just sub -> searchTrie cs sub
startsWithTrie :: String -> Trie -> Bool
startsWithTrie [] _ = True
startsWithTrie (c : cs) (Trie _ subs) = case M.lookup c subs of
Nothing -> False
Just sub -> startsWithTrie cs sub
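As a quick sanity check (assuming imports of qualified Data.Map as M and fromMaybe from Data.Maybe), we might try something like this:
trieExample :: (Bool, Bool, Bool)
trieExample =
  let t = foldr insertTrie (Trie False M.empty) ["at", "ate", "bank"]
  in ( searchTrie "ate" t      -- True
     , searchTrie "ban" t      -- False ("ban" is only a prefix)
     , startsWithTrie "ban" t  -- True
     )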
Rust
The Rust solution follows the same algorithmic ideas, but the code looks quite a bit different. Rust allows mutable data structures, so building them feels a bit more like it would in a typical object-oriented language. But there are some interesting quirks! Here’s the frame LeetCode gives you to work with:
struct Trie {
...
}
impl Trie {
fn new() -> Self {
...
}
fn insert(&mut self, word: String) {
...
}
fn search(&self, word: String) -> bool {
...
}
fn starts_with(&self, prefix: String) -> bool {
...
}
}
With Rust we define the fields of the struct, and then create an impl for the type with all the relevant functions. Each “class” method takes a self parameter, somewhat like Python. This can be mutable or not. A Rust “constructor” is typically done with a new function, as you see.
Despite Rust’s trickiness with ownership and borrowing, there are no obstacles to making a recursive data type in the same manner as our Haskell implementation. Here’s our struct definition, as well as the constructor:
use std::collections::HashMap;
struct Trie {
endsWord: bool,
nodes: HashMap<char, Trie>
}
impl Trie {
fn new() -> Self {
Trie {
endsWord: false,
nodes: HashMap::new()
}
}
...
}
When it comes to the main functions though, we don’t want to make them directly recursive. Each takes a String input, and constructing a new String that pops the first character isn’t efficient. Rust is also peculiar in that you cannot index into strings, due to ambiguity arising from character encodings. So our solution will be to use the Chars iterator to efficiently “pop” characters while being able to examine the characters that come next.

So let’s start by making ...It versions of all our functions that take Chars iterators. We’ll be able to call these functions recursively. We invoke each one from the base function by calling .chars() on the input string.
use std::collections::HashMap;
use std::str::Chars;
struct Trie {
endsWord: bool,
nodes: HashMap<char, Trie>
}
impl Trie {
fn new() -> Self {
Trie {
endsWord: false,
nodes: HashMap::new()
}
}
fn insert(&mut self, word: String) {
self.insertIt(word.chars());
}
fn insertIt(&mut self, mut iter: Chars) {
...
}
fn search(&self, word: String) -> bool {
return self.searchIt(word.chars());
}
fn searchIt(&self, mut iter: Chars) -> bool {
...
}
fn starts_with(&self, prefix: String) -> bool {
return self.startsWithIt(prefix.chars());
}
fn startsWithIt(&self, mut iter: Chars) -> bool {
...
}
}
Now let’s zero in on these implementations, starting with insertIt. First we pattern match on the iterator. If it gives us None, we just set self.endsWord = true and we’re done.
impl Trie {
...
fn insertIt(&mut self, mut iter: Chars) {
if let Some(c) = iter.next() {
...
} else {
self.endsWord = true;
}
}
}
Now, just like in Haskell, if this node doesn’t have a sub-Trie for the character c yet, we insert a new Trie for c. Then we recursively call insertIt on this “subTrie”.
impl Trie {
fn insertIt(&mut self, mut iter: Chars) {
if let Some(c) = iter.next() {
if !self.nodes.contains_key(&c) {
self.nodes.insert(c, Trie::new());
}
if let Some(subTrie) = self.nodes.get_mut(&c) {
subTrie.insertIt(iter);
}
} else {
self.endsWord = true;
}
}
}
That’s it for insertion. Now for searching, we’ll follow the same pattern matching protocol. If the Chars iterator is empty, we just check whether this node has endsWord set or not:
impl Trie {
    fn searchIt(&self, mut iter: Chars) -> bool {
        if let Some(c) = iter.next() {
            ...
        } else {
            return self.endsWord;
        }
    }
}
We check again for a sub-Trie under the key c. If it doesn’t exist, we return false. If it does, we just make a recursive call!
impl Trie {
    fn searchIt(&self, mut iter: Chars) -> bool {
        if let Some(c) = iter.next() {
            if let Some(subTrie) = self.nodes.get(&c) {
                return subTrie.searchIt(iter);
            } else {
                return false;
            }
        } else {
            return self.endsWord;
        }
    }
}
And with startsWith, it’s the same pattern. We do exactly the same thing as search, except that the final else case is unambiguously true.
impl Trie {
fn startsWithIt(&self, mut iter: Chars) -> bool {
if let Some(c) = iter.next() {
if let Some(subTrie) = self.nodes.get(&c) {
return subTrie.startsWithIt(iter);
} else {
return false;
}
} else {
return true;
}
}
}
Here’s our final Rust implementation:
use std::collections::HashMap;
use std::str::Chars;
struct Trie {
endsWord: bool,
nodes: HashMap<char, Trie>
}
impl Trie {
fn new() -> Self {
Trie {
endsWord: false,
nodes: HashMap::new()
}
}
fn insert(&mut self, word: String) {
self.insertIt(word.chars());
}
fn insertIt(&mut self, mut iter: Chars) {
if let Some(c) = iter.next() {
if !self.nodes.contains_key(&c) {
self.nodes.insert(c, Trie::new());
}
if let Some(subTrie) = self.nodes.get_mut(&c) {
subTrie.insertIt(iter);
}
} else {
self.endsWord = true;
}
}
fn search(&self, word: String) -> bool {
return self.searchIt(word.chars());
}
fn searchIt(&self, mut iter: Chars) -> bool {
if let Some(c) = iter.next() {
if let Some(subTrie) = self.nodes.get(&c) {
return subTrie.searchIt(iter);
} else {
return false;
}
} else {
return self.endsWord;
}
}
fn starts_with(&self, prefix: String) -> bool {
return self.startsWithIt(prefix.chars());
}
fn startsWithIt(&self, mut iter: Chars) -> bool {
if let Some(c) = iter.next() {
if let Some(subTrie) = self.nodes.get(&c) {
return subTrie.startsWithIt(iter);
} else {
return false;
}
} else {
return true;
}
}
}
Conclusion
It’s always interesting to practice making recursive data structures in a new language. While Rust shares some things in common with Haskell, making data structures still feels more like other object-oriented languages than Haskell. Next week, we’ll put our Trie implementation to use by solving a problem that requires this data structure!
To learn more about writing your own data structures in Haskell, take our course, Solve.hs! Module 2 focuses heavily on data structures, and you’ll learn how to make some derived data structures that will improve your programs!
Topological Sort: Managing Mutable Structures in Haskell
Welcome back to our Rust vs. Haskell comparison series, featuring some of the most common LeetCode questions. We’ve done a couple graph problems the last two weeks, involving DFS and BFS.
Today we’ll do a graph problem involving a slightly more complicated algorithm. We’ll also use a couple data structures we haven’t seen in this series yet, and we’ll see how tricky it can get to have multiple mutable structures in a Haskell algorithm.
To learn all the details of managing your data structures in Haskell, check out Solve.hs, our problem solving course. You’ll learn all the key APIs, important algorithms, and you’ll get a lot of practice with LeetCode style questions!
The Problem
Today’s problem is called Course Schedule. We are given a number of courses, and a list of prerequisites among those courses. For a prerequisite pair (A,B), we cannot take Course A until we have taken Course B. Our job is to determine, in a sense, if the prerequisite list is well-defined. We want to see whether or not the list would actually allow us to take all the courses.
As an example, suppose we had these inputs:
Number Courses: 4
Prerequisites: [(2, 0), (1,0), (3,1), (3,2)]
This is a well-defined set of courses. In order to take courses 1 and 2, we must take course 0. Then in order to take course 3, we have to take courses 1 and 2. So if we follow the ordering 0->1->2->3, we can take all the courses, and we would return True.

However, if we were to add (1,3) to the list, we would not be able to take all the courses. We could take courses 0 and 2, but then we would be stuck because 1 and 3 have a mutual dependency. So we would return False with this list.
We are guaranteed that the course indices in the prerequisites list are in the range [0, numCourses - 1]. We are also guaranteed that all prerequisites are unique.
The Algorithm
For our algorithm, we will imagine these courses as living in a directed graph. If course A is a prerequisite of course B, there should be a directed edge from A to B. This problem essentially boils down to determining whether this graph has a cycle.
There are many ways to approach this, including relying on DFS or BFS as we discussed in the past two weeks! However, to introduce a new idea, we’ll solve this problem using the idea of topological sorting.
We can think of nodes as having “in degrees”. The “in degree” of a node is the number of directed edges coming into it. We are particularly concerned with nodes that have an in degree of 0. These are courses with no prerequisites, which we can take immediately.
Each time we “take” a course, we can increment a count of the courses we’ve taken, and then we can “remove” that node from the graph by decrementing the in degrees of all nodes that it is pointing to. If any of these nodes have their in degrees drop to 0 as a result of this, we can then add them to a queue of “0 degree nodes”.
If, once the queue is exhausted, we’ve taken every course, then we have proven that we can satisfy all the requirements! If not, then there must be a cycle preventing some nodes from ever having in-degree 0.
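Here is how that plays out on the earlier example (4 courses, prerequisites [(2,0), (1,0), (3,1), (3,2)]):
initial in-degrees: 0:0, 1:1, 2:1, 3:2      queue = [0]
take 0 -> decrement 1 and 2 (both reach 0)  queue = [1, 2]
take 1 -> decrement 3 (now 1)               queue = [2]
take 2 -> decrement 3 (now 0)               queue = [3]
take 3 -> nothing left to unlock            queue = []
We took all 4 courses, so the schedule is satisfiable and we return True.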
Rust Solution
We’ll start with a Rust solution. We need to manage a few different structures in this problem. The first two will be vectors giving us information about each course. We want to know the current “in degree” as well as having a list of the courses “unlocked” by each course.
Each “prerequisite” pair gives the unlocked course first, and then the prerequisite course. We’ll call these “post” and “pre”, respectively. We increase the in-degree of “post” and add “post” to the list of courses unlocked by “pre”:
pub fn can_finish(num_courses: i32, prerequisites: Vec<Vec<i32>>) -> bool {
// More convenient to use usize
let n = num_courses as usize;
let mut inDegrees = Vec::with_capacity(n);
inDegrees.resize(n, 0);
// Maps from “pre” course to “post” course
let mut unlocks: Vec<Vec<usize>> = Vec::with_capacity(n);
unlocks.resize(n, Vec::new());
for req in prerequisites {
let post = req[0] as usize;
let pre = req[1] as usize;
inDegrees[post] += 1;
unlocks[pre].push(post);
}
...
}
Now we need to make a queue of 0-degree nodes. This uses VecDeque from last time. We’ll go through the initial in-degrees list and add all the nodes that are already 0. Then we’ll set up our loop to pop the front element until empty:
pub fn can_finish(num_courses: i32, prerequisites: Vec<Vec<i32>>) -> bool {
let n = num_courses as usize;
...
// Make a queue of 0 degree
let mut queue: VecDeque<usize> = VecDeque::new();
for i in 0..(num_courses as usize) {
if inDegrees[i] == 0 {
queue.push_back(i);
}
}
let mut numSatisfied = 0;
while let Some(course) = queue.pop_front() {
...
}
return numSatisfied == num_courses;
}
All we have to do now is process the course at the front of the queue each time. We always increment the number of courses satisfied, since de-queuing a course indicates we are taking it. Then we loop through its unlocks and decrement each of their in-degrees. If reducing an in-degree takes it to 0, then we add this unlocked course to the back of the queue:
pub fn can_finish(num_courses: i32, prerequisites: Vec<Vec<i32>>) -> bool {
let n = num_courses as usize;
...
let mut numSatisfied = 0;
while let Some(course) = queue.pop_front() {
numSatisfied += 1;
for post in &unlocks[course] {
inDegrees[*post] -= 1;
if (inDegrees[*post] == 0) {
queue.push_back(*post);
}
}
}
return numSatisfied == num_courses;
}
This completes our solution! Here is the full Rust implementation:
pub fn can_finish(num_courses: i32, prerequisites: Vec<Vec<i32>>) -> bool {
let n = num_courses as usize;
// Make a vector with inDegree Count
let mut inDegrees = Vec::with_capacity(n);
inDegrees.resize(n, 0);
// Make a vector of "unlocks"
let mut unlocks: Vec<Vec<usize>> = Vec::with_capacity(n);
unlocks.resize(n, Vec::new());
for req in prerequisites {
let post = req[0] as usize;
let pre = req[1] as usize;
inDegrees[post] += 1;
unlocks[pre].push(post);
}
// Make a queue of 0 degree
let mut queue: VecDeque<usize> = VecDeque::new();
for i in 0..(num_courses as usize) {
if inDegrees[i] == 0 {
queue.push_back(i);
}
}
let mut numSatisfied = 0;
while let Some(course) = queue.pop_front() {
numSatisfied += 1;
for post in &unlocks[course] {
inDegrees[*post] -= 1;
if (inDegrees[*post] == 0) {
queue.push_back(*post);
}
}
}
return numSatisfied == num_courses;
}
Haskell Solution
In Haskell, we can follow this same approach. However, this is a somewhat challenging algorithm for Haskell beginners, because there are a lot of data structure “modifications” occurring, and expressions in Haskell are immutable! So we’ll organize our solution into three different parts:
- Initializing our structures
- Writing loop modifiers
- Writing the loop
This solution will introduce two data structures we haven’t used in this series so far: the IntMap and the Sequence (Seq), which we’ll use qualified like so:
import qualified Data.IntMap.Lazy as IM
import qualified Data.Sequence as Seq
The IntMap type works more or less exactly like a normal Map, with the same API. However, it assumes we have Int as our key type, which makes certain operations more efficient than a generic ordered map.

Then Seq is the best structure to use for a FIFO queue. We would have used it last week if we had implemented BFS from scratch.
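Here is a minimal sketch (assuming import qualified Data.Sequence as Seq) of the two Seq operations we’ll lean on: pushing onto the back with (Seq.|>) and popping the front with Seq.viewl.
queueDemo :: (Int, Seq.Seq Int)
queueDemo =
  let q = Seq.fromList [1, 2] Seq.|> 3  -- push 3 onto the back: 1, 2, 3
  in case Seq.viewl q of
       Seq.EmptyL -> error "empty queue"
       (x Seq.:< rest) -> (x, rest)     -- pops 1 from the front, leaving 2, 3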
We’ll also make a few type aliases, since we’ll be combining these structures and frequently using them in type signatures:
type DegCount = IM.IntMap Int
type CourseMaps = (DegCount, IM.IntMap [Int])
type CourseState = (Int, Seq.Seq Int, DegCount)
The setup to our problem is fairly simple. Our function takes the number of courses as an integer, and the prerequisites as a list of tuples. We’ll write a number of helper functions beneath this top level definition, but for additional clarity, we’ll show them independently as we write them.
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = ...
Initializing Our Structures
Recall that the first part of our Rust solution focused on populating 3 structures:
- The list of in-degrees (per node)
- The list of “unlocks” (per node)
- The initial queue of 0-degree nodes
We use IntMaps for the first two (and use the alias DegCount for the first). These are easier to modify than vectors in Haskell. The other noteworthy fact is that we want to create these together (this is why we have the CourseMaps alias combining them). We process each prerequisite pair, updating both of these maps. This means we want to write a folding function like so:
processPrereq :: (Int, Int) -> CourseMaps -> CourseMaps
For this function, we want to define two more helpers. One that will make it easier to increment the key of a degree value, and one that will make it easy to append a new unlock for the other mapping.
incKey :: Int -> DegCount -> DegCount
appendUnlock :: Int -> Int -> IM.IntMap [Int] -> IM.IntMap [Int]
These two helpers are straightforward to implement. In each case, we check for the key existing. If it doesn’t exist, we insert the default value (either 1 or a singleton list). If it exists, we either increment the value for the degree, or we append the new unlocked course to the existing list.
incKey :: Int -> DegCount -> DegCount
incKey k mp = case IM.lookup k mp of
Nothing -> IM.insert k 1 mp
Just x -> IM.insert k (x + 1) mp
appendUnlock :: Int -> Int -> IM.IntMap [Int] -> IM.IntMap [Int]
appendUnlock pre post mp = case IM.lookup pre mp of
Nothing -> IM.insert pre [post] mp
Just prev -> IM.insert pre (post : prev) mp
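As an aside, both helpers could be written more compactly with IM.insertWith, which applies its combining function to the new and old values when the key already exists. Here is an equivalent sketch:
incKey' :: Int -> DegCount -> DegCount
incKey' k = IM.insertWith (+) k 1

appendUnlock' :: Int -> Int -> IM.IntMap [Int] -> IM.IntMap [Int]
appendUnlock' pre post = IM.insertWith (++) pre [post]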
Now it’s very tidy to implement our folding function, and apply it to get these initial values:
processPrereq :: (Int, Int) -> CourseMaps -> CourseMaps
processPrereq (post, pre) (inDegrees', unlocks') =
(incKey post inDegrees', appendUnlock pre post unlocks')
Here’s where our function currently is then:
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = ...
where
(inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
Now we want to build our initial queue as well. For this, we just want to loop through the possible course numbers, and add any that are not in the map for inDegrees (we never insert something with a value of 0).
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = ...
where
(inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
queue = Seq.fromList
(filter (`IM.notMember` inDegrees) [0..numCourses-1])
Writing Loop Modifiers
Now we have to consider what structures are going to be part of our “loop” and how we’re going to modify them. The type alias CourseState already expresses our loop state. We want to track the number of courses satisfied so far, the queue of 0-degree nodes, and the remaining in-degree values.
The key modification is that we can reduce the in-degrees of remaining courses. When we do this, we want to know immediately if we reduced the in-degree to 0. So let’s write a function that decrements the value, except that it deletes the key entirely if it drops to 0. We’ll return a boolean indicating if the key no longer exists in the map after this process:
decKey :: Int -> DegCount -> (DegCount, Bool)
decKey key mp = case IM.lookup key mp of
Nothing -> (mp, True)
Just x -> if x <= 1
then (IM.delete key mp, True)
else (IM.insert key (x - 1) mp, False)
Now what’s the core function of the loop? When we “take” a course, we loop through its unlocks, reduce all their degrees, and track which ones are now 0. Since this is a loop that updates state (the remaining inDegrees), we want to write a folding function for it:
decDegree :: Int -> (DegCount, [Int]) -> (DegCount, [Int])
First we perform the decrement. Then if decKey returns True, we’ll add the course to our new0s list.
decDegree :: Int -> (DegCount, [Int]) -> (DegCount, [Int])
decDegree post (inDegrees', new0s) =
let (inDegrees'', removed) = decKey post inDegrees'
in (inDegrees'', if removed then (post : new0s) else new0s)
Writing the Loop
With all these helpers at our disposal, we can finally write our core loop. Recall the 3 parts of our loop state: the number of courses taken so far, the queue of 0-degree courses, and the in-degree values. This loop should just return the number of courses completed:
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = ...
where
(inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
queue = Seq.fromList
(filter (`IM.notMember` inDegrees) [0..numCourses-1])
loop :: CourseState -> Int
loop (numSatisfied, queue', inDegrees') = ...
If the queue is empty, we just return our accumulated number. While we’re at it, the final action is to simply compare this loop result to the total number of courses to get our final result:
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = loop (0, queue, inDegrees) == numCourses
where
(inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
queue = Seq.fromList
(filter (`IM.notMember` inDegrees) [0..numCourses-1])
loop :: CourseState -> Int
loop (numSatisfied, queue', inDegrees') = case Seq.viewl queue' of
Seq.EmptyL -> numSatisfied
(course Seq.:< rest) -> ...
When we “pop” the first course off of the queue, we first get the list of “post” courses that could now be unlocked by this course. Then we can apply our decDegree helper to get the final inDegrees'' map and the “new 0’s”.
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = loop (0, queue, inDegrees) == numCourses
where
(inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
queue = Seq.fromList
(filter (`IM.notMember` inDegrees) [0..numCourses-1])
loop :: CourseState -> Int
loop (numSatisfied, queue', inDegrees') = case Seq.viewl queue' of
Seq.EmptyL -> numSatisfied
(course Seq.:< rest) ->
let posts = fromMaybe [] (IM.lookup course unlocks)
(inDegrees'', new0s) = foldr decDegree (inDegrees', []) posts
...
Finally, we append the new 0’s to the end of the queue, and we make our recursive call, completing the loop and the function!
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = loop (0, queue, inDegrees) == numCourses
where
(inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
queue = Seq.fromList
(filter (`IM.notMember` inDegrees) [0..numCourses-1])
loop :: CourseState -> Int
loop (numSatisfied, queue', inDegrees') = case Seq.viewl queue' of
Seq.EmptyL -> numSatisfied
(course Seq.:< rest) ->
let posts = fromMaybe [] (IM.lookup course unlocks)
(inDegrees'', new0s) = foldr decDegree (inDegrees', []) posts
queue'' = foldl (Seq.|>) rest new0s
in loop (numSatisfied + 1, queue'', inDegrees'')
Here’s the full solution, from start to finish:
type DegCount = IM.IntMap Int
type CourseMaps = (DegCount, IM.IntMap [Int])
type CourseState = (Int, Seq.Seq Int, DegCount)
canFinishCourses :: Int -> [(Int, Int)] -> Bool
canFinishCourses numCourses prereqs = loop (0, queue, inDegrees) == numCourses
where
incKey :: Int -> DegCount -> DegCount
incKey k mp = case IM.lookup k mp of
Nothing -> IM.insert k 1 mp
Just x -> IM.insert k (x + 1) mp
appendUnlock :: Int -> Int -> IM.IntMap [Int] -> IM.IntMap [Int]
appendUnlock pre post mp = case IM.lookup pre mp of
Nothing -> IM.insert pre [post] mp
Just prev -> IM.insert pre (post : prev) mp
processPrereq :: (Int, Int) -> CourseMaps -> CourseMaps
processPrereq (post, pre) (inDegrees', unlocks') =
(incKey post inDegrees', appendUnlock pre post unlocks')
(inDegrees, unlocks) = foldr processPrereq (IM.empty, IM.empty) prereqs
queue = Seq.fromList
(filter (`IM.notMember` inDegrees) [0..numCourses-1])
decKey :: Int -> DegCount -> (DegCount, Bool)
decKey key mp = case IM.lookup key mp of
Nothing -> (mp, True)
Just x -> if x <= 1
then (IM.delete key mp, True)
else (IM.insert key (x - 1) mp, False)
decDegree :: Int -> (DegCount, [Int]) -> (DegCount, [Int])
decDegree post (inDegrees', new0s) =
let (inDegrees'', removed) = decKey post inDegrees'
in (inDegrees'', if removed then (post : new0s) else new0s)
loop :: CourseState -> Int
loop (numSatisfied, queue', inDegrees') = case Seq.viewl queue' of
Seq.EmptyL -> numSatisfied
(course Seq.:< rest) ->
let posts = fromMaybe [] (IM.lookup course unlocks)
(inDegrees'', new0s) = foldr decDegree (inDegrees', []) posts
queue'' = foldl (Seq.|>) rest new0s
in loop (numSatisfied + 1, queue'', inDegrees'')
Conclusion
This problem showed the challenge of working with multiple mutable types in Haskell loops. You have to be very diligent about tracking what pieces are mutable, and you often need to write a lot of helper functions to keep your code clean. In our course, Solve.hs, you’ll learn about writing compound data structures to help you solve problems more cleanly. A Graph is one example, and you’ll also learn about occurrence maps, which we could have used in this problem.
That’s all for graphs right now. In the next couple weeks, we’ll cover the Trie, a compound data structure that can help with some very specific problems.
Graph Algorithms in Board Games!
For last week’s problem we started learning about graph algorithms, focusing on depth-first-search. Today we’ll do a problem from an old board game that will require us to use breadth-first-search. We’ll also learn about a special library in Haskell that lets us solve these types of problems without needing to implement all the details of these algorithms.
To learn more about this library and graph algorithms in Haskell, you should check out our problem solving course, Solve.hs! Module 3 of the course focuses on algorithms, with a special emphasis on graph algorithms!
The Problem
Today’s problem comes from a kids’ board game called Snakes and Ladders, which will take a little bit to explain. First, imagine we have a square board in an N x N grid, where each cell is numbered 1 to N^2. The bottom left corner is always “1”, and numbers increase in a snake-like fashion. First they increase from left to right along the bottom row. Then they go from right to left in the next row, before reversing again. Here’s what the numbers look like for a 6x6 board:
36 35 34 33 32 31
25 26 27 28 29 30
24 23 22 21 20 19
13 14 15 16 17 18
12 11 10 9 8 7
1 2 3 4 5 6
The “goal” is to reach the highest numbered tile, which is either in the top left (for even grid sizes) or the top right (for odd grid sizes). One moves by rolling a 6-sided die. Given the number on the die, you are entitled to move that many spaces. The ordinary path of movement is following the increasing numbers.
As is, the game is a little boring. You just always want to roll the highest number you can. However, various cells on the grid are equipped with “snakes” or “ladders”, which can move you around the grid if your die roll would cause your turn to end where these items start. Ladders typically move you closer to the goal, while snakes typically move you away from it.
We can represent such a board by putting an integer on each cell. The integer -1
represents an ordinary cell, where you would simply proceed to the next cell in order. However, we can represent the start of each snake and ladder with a number corresponding to the cell number where you end up if your die roll lands you there. Here’s an example:
-1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1
-1 35 -1 -1 13 -1
-1 -1 -1 -1 -1 -1
-1 15 -1 -1 -1 -1
This grid has two ladders. The first can take you from position 2 to position 15 (see the bottom left corner). The second can take you from position 14 to position 35. There is also a snake that will take you back from position 17 to position 13. Note that no matter the layout, you can only follow one snake or ladder on a turn. If you end your turn at the beginning of a ladder, which takes you to the beginning of another ladder, you do not take the second ladder on that turn.
Our objective is to find the smallest number of dice rolls possible to reach the goal cell (which will always have -1). In this case, the answer is 4. Various combinations of 3 rolls can land us on 14, which will take us to 35. Then rolling 1 would take us to the goal.
It is possible to contrive a board where it is impossible to reach the goal! We need to handle these cases. In these situations we must return -1. Here is such a grid, with many snakes, all leading back to the start!
1 1 -1
1 1 1
-1 1 1
The Algorithm
This is a graph search problem where each step we take carries the same weight (one turn), and we are trying to find the shortest path. This makes it a canonical example of a Breadth First Search problem (BFS).
We solve BFS by maintaining a queue of search states. In our case, the search state might consist simply of our location, though we may also want to track the number of steps we needed to reach that location as part of the state.
We’ll have a single primary loop, where we remove the first element in our queue. We’ll find all its “neighbors” (the states reachable from that node), and place these on the end of the queue. Then we’ll continue processing.
BFS works out so that states with a lower “cost” (i.e. number of turns) will all be processed before any states with higher cost. This means that the first time we dequeue a goal state from our queue, we can be sure we have found the shortest path to that goal state.
As with last week’s problem, we’ll spend a fair amount of effort on our “neighbors” function, which is often the core of a graph solution. Once we have that in place, the mechanics of the graph search generally become quite easy.
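Before we dive in, here’s a minimal sketch of this queue-based loop in Haskell (not the code we’ll end up writing below, just an illustration of the mechanics, assuming Data.Sequence for the queue and Data.Set for the visited set):
import qualified Data.Sequence as Seq
import qualified Data.Set as S

-- Generic BFS: the number of steps from the start to the first goal state, if any.
bfsSteps :: Ord s => (s -> [s]) -> (s -> Bool) -> s -> Maybe Int
bfsSteps neighbors isGoal start = go (Seq.singleton (start, 0)) (S.singleton start)
  where
    go queue visited = case Seq.viewl queue of
      Seq.EmptyL -> Nothing
      ((st, steps) Seq.:< rest)
        | isGoal st -> Just steps
        | otherwise ->
            let ns = filter (`S.notMember` visited) (neighbors st)
                visited' = foldr S.insert visited ns
                queue' = foldl (Seq.|>) rest [(n, steps + 1) | n <- ns]
            in go queue' visited'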
Rust Solution
Once again we’ll start with Rust, because we’ll use a special trick in Haskell. As stated, we want to start with our neighbors function. We’ll represent a single location just using the integer representing it on the board, not its grid coordinates. So at root, we’re taking one usize
and returning a vector of usize
values. But we’ll also take the board (a 2D vector of integers) so we can follow the snakes and ladders. Finally, we’ll pass the size of the board (just N, since our board is always square) and the “goal” location so that we don’t have to recalculate these every time:
pub fn neighbors(n: usize, goal: usize, board: &Vec<Vec<i32>>, loc: usize) -> Vec<usize> {
let mut results = Vec::new();
...
return results;
}
The basic idea of this function is that we’ll loop through the possible die rolls (1 to 6) and return the resulting location from each roll. If we find that the roll would take us past the goal, then we can safely break:
pub fn neighbors(n: usize, goal: usize, board: &Vec<Vec<i32>>, loc: usize) -> Vec<usize> {
let mut results = Vec::new();
for i in 1..=6 {
if loc + i > goal {
break;
}
...
}
return results;
}
How do we actually get the resulting location? We need to use the board, but in order to use the board, we have to convert the location into 2D coordinates. So let’s just write the frame for a function converting a location into coordinates. We’ll fill it in later:
pub fn convert(n: usize, loc: usize) -> (usize, usize) {
...
}
Assuming we have this function, the rest of our neighbors
logic is easy. We check the corresponding value for the location in board
. If it is -1
, we just use our prior location added to the die roll. Otherwise, we use the location given in the cell:
pub fn neighbors(n: usize, goal: usize, board: &Vec<Vec<i32>>, loc: usize) -> Vec<usize> {
let mut results = Vec::new();
for i in 1..=6 {
if loc + i > goal {
break;
}
let (row, col) = convert(n, loc + i);
let next = board[row][col];
if next == -1 {
results.push(loc + i);
} else {
results.push(next as usize);
}
}
return results;
}
So let’s fill in this conversion function. It’s tricky because of the snaking order of the board and because we start from the bottom (highest row index) and not the top. Nonetheless, we want to start by getting the quotient and remainder of our location with the side-length. (We subtract 1 since our locations are 1-indexed).
pub fn convert(n: usize, loc: usize) -> (usize, usize) {
let rowBase = (loc - 1) / n;
let colBase = (loc - 1) % n;
...
}
To get the final row, we simply take n - rowBase - 1
. The column is trickier. We need to consider if the row base is even or odd. If it is even, the row is going from left to right. Otherwise, it goes from right to left. In the first case, the modulo for the column gives us the right column. In the second case, we need to subtract from n
like we did with rows.
pub fn convert(n: usize, loc: usize) -> (usize, usize) {
let rowBase = (loc - 1) / n;
let colBase = (loc - 1) % n;
let row = n - rowBase - 1;
let col =
if rowBase % 2 == 0 {
colBase
} else {
n - colBase - 1
};
return (row, col);
}
But that’s all we need for conversion!
Now that our neighbors
function is closed up, we can finally write the core solution. For the Rust solution, we’ll define our “search state” as including the location and the number of steps we took to reach it, so a tuple (usize, usize)
. We’ll create a VecDeque
of these, which is Rust’s structure for a queue, and insert our initial state (location 1, count 0):
use std::collections::VecDeque;
pub fn snakes_and_ladders(board: Vec<Vec<i32>>) -> i32 {
let n = board.len();
let goal = board.len() * board[0].len();
let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
queue.push_back((1,0));
...
}
We also want to track the locations we’ve already visited. This will be a hash set of the locations but not the counts. This is necessary to prevent infinite loops. Once we’ve visited a location there is no advantage to considering it again on a later branch (with this problem at least). We’ll also follow the practice of considering a cell “visited” once it is enqueued.
use std::collections::VecDeque;
use std::collections::HashSet;
pub fn snakes_and_ladders(board: Vec<Vec<i32>>) -> i32 {
let n = board.len();
let goal = board.len() * board[0].len();
let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
queue.push_back((1,0));
let mut visited = HashSet::new();
visited.insert(1);
...
}
Now we’ll run a loop popping the front of the queue and finding the “neighboring” locations. If our queue is empty, this indicates no path was possible, so we return -1.
use std::collections::VecDeque;
use std::collections::HashSet;
pub fn snakes_and_ladders(board: Vec<Vec<i32>>) -> i32 {
let n = board.len();
let goal = board.len() * board[0].len();
let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
queue.push_back((1,0));
let mut visited = HashSet::new();
visited.insert(1);
while let Some((idx, count)) = queue.pop_front() {
let ns = neighbors(n, goal, &board, idx);
...
}
return -1;
}
Now processing each neighbor is simple. First, if the neighbor is the goal, we’re done! Just return the dequeued count plus 1. Otherwise, check if we’ve visited the neighbor before. If not, push it to the back of the queue, along with an increased count:
pub fn snakes_and_ladders(board: Vec<Vec<i32>>) -> i32 {
let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
queue.push_back((1,0));
let n = board.len();
let goal = board.len() * board[0].len();
let mut visited = HashSet::new();
visited.insert(1);
while let Some((idx, count)) = queue.pop_front() {
let ns = neighbors(n, goal, &board, idx);
for next in ns {
if next == goal {
return (count + 1) as i32;
}
if !visited.contains(&next) {
queue.push_back((next, count + 1));
visited.insert(next);
}
}
}
return -1;
}
This completes our BFS solution! Here is the complete code:
use std::collections::VecDeque;
use std::collections::HashSet;
pub fn convert(n: usize, loc: usize) -> (usize, usize) {
let rowBase = (loc - 1) / n;
let colBase = (loc - 1) % n;
let row = n - rowBase - 1;
let col =
if rowBase % 2 == 0 {
colBase
} else {
n - colBase - 1
};
return (row, col);
}
pub fn neighbors(n: usize, goal: usize, board: &Vec<Vec<i32>>, loc: usize) -> Vec<usize> {
let mut results = Vec::new();
for i in 1..=6 {
if loc + i > goal {
break;
}
let (row, col) = convert(n, loc + i);
let next = board[row][col];
if next == -1 {
results.push(loc + i);
} else {
results.push(next as usize);
}
}
return results;
}
pub fn snakes_and_ladders(board: Vec<Vec<i32>>) -> i32 {
let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
queue.push_back((1,0));
let n = board.len();
let goal = board.len() * board[0].len();
let mut visited = HashSet::new();
visited.insert(1);
while let Some((idx, count)) = queue.pop_front() {
let ns = neighbors(n, goal, &board, idx);
for next in ns {
if next == goal {
return (count + 1) as i32;
}
if !visited.contains(&next) {
queue.push_back((next, count + 1));
visited.insert(next);
}
}
}
return -1;
}
Haskell Solution
For our Haskell solution, we’re going to use a special shortcut. We’ll make use of the Algorithm.Search library to handle the mechanics of the BFS for us. The function we’ll use has this type signature (slightly simplified):
bfs :: (state -> [state]) -> (state -> Bool) -> state -> Maybe [state]
We provide 3 inputs. First is the “neighbors” function, taking one state and returning its neighbors. Second is the “goal” function, telling us if a state is our final goal state. Finally we give it the initial state. If a goal is reachable, we receive a path to that goal. If not, we receive Nothing
. Since this library provides the full path for us automatically, we won’t track the number of steps in our state. Our “state” will simply be the location. So let’s begin by framing out our function:
snakesAndLadders :: A.Array (Int, Int) Int -> Int
snakesAndLadders board = ...
where
((minRow, minCol), (maxRow, _)) = A.bounds board
n = maxRow - minRow + 1
goal = n * n
convert :: Int -> (Int, Int)
neighbor :: Int -> Int
neighbors :: Int -> [Int]
Let’s start with convert
. This follows the same rules we used in our Rust solution, so there’s not much to say here. We just have to make sure we account for non-zero start indices in Haskell arrays by adding minRow
and minCol
.
snakesAndLadders :: A.Array (Int, Int) Int -> Int
snakesAndLadders board = ...
where
((minRow, minCol), (maxRow, _)) = A.bounds board
n = maxRow - minRow + 1
goal = n * n
convert :: Int -> (Int, Int)
convert loc =
let (rowBase, colBase) = (loc - 1) `quotRem` n
row = minRow + (n - rowBase - 1)
col = minCol + if even rowBase then colBase else n - colBase - 1
in (row, col)
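As a quick sanity check of this formula (just an illustration, assuming a 0-indexed array so that minRow and minCol are both 0), consider location 14 on the 6x6 board from earlier:
-- (14 - 1) `quotRem` 6 == (2, 1)
-- row = 0 + (6 - 2 - 1) = 3, and since rowBase 2 is even, col = 0 + 1 = 1.
-- Counting from the top left, cell (3, 1) is indeed where 14 sits in the 6x6 grid above.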
Now we’ll write a neighbor
helper that converts a single location. This just makes our neighbors
function a lot cleaner. We use the same logic of checking for -1
in the board, or else using the value we find there.
snakesAndLadders :: A.Array (Int, Int) Int -> Int
snakesAndLadders board = ...
where
((minRow, minCol), (maxRow, _)) = A.bounds board
n = maxRow - minRow + 1
goal = n * n
convert = ...
neighbor :: Int -> Int
neighbor loc =
let coord = convert loc
onBoard = board A.! coord
in if onBoard == -1 then loc else onBoard
Now we can write neighbors
with a simple list comprehension. We look through each roll of 1-6, add it to the current location, filter if this location is past the goal, and then calculate the neighbor.
snakesAndLadders :: A.Array (Int, Int) Int -> Int
snakesAndLadders board = ...
where
((minRow, minCol), (maxRow, _)) = A.bounds board
n = maxRow - minRow + 1
goal = n * n
convert = ...
neighbor = ...
neighbors :: Int -> [Int]
neighbors loc =
[neighbor (loc + i) | i <- [1..6], loc + i <= goal]
Now for the coup de grâce. We call bfs
with our neighbors
function. The “goal” function is just (== goal)
, and the starting state is just 1
. It will return our shortest path, and so we just return its length:
snakesAndLadders :: A.Array (Int, Int) Int -> Int
snakesAndLadders board = case bfs neighbors (== goal) 1 of
Nothing -> (-1)
Just path -> length path
where
((minRow, minCol), (maxRow, _)) = A.bounds board
n = maxRow - minRow + 1
goal = n * n
convert :: Int -> (Int, Int)
convert loc =
let (rowBase, colBase) = (loc - 1) `quotRem` n
row = minRow + (n - rowBase - 1)
col = minCol + if even rowBase then colBase else n - colBase - 1
in (row, col)
neighbor :: Int -> Int
neighbor loc =
let coord = convert loc
onBoard = board A.! coord
in if onBoard == -1 then loc else onBoard
neighbors :: Int -> [Int]
neighbors loc =
[neighbor (loc + i) | i <- [1..6], loc + i <= goal]
And that’s our complete Haskell solution!
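As a quick check (a hypothetical snippet, assuming Data.Array is imported as A and the array is 0-indexed), we can build the example board from earlier and confirm that we get 4:
exampleBoard :: A.Array (Int, Int) Int
exampleBoard = A.listArray ((0, 0), (5, 5)) $ concat
  [ [-1, -1, -1, -1, -1, -1]
  , [-1, -1, -1, -1, -1, -1]
  , [-1, -1, -1, -1, -1, -1]
  , [-1, 35, -1, -1, 13, -1]
  , [-1, -1, -1, -1, -1, -1]
  , [-1, 15, -1, -1, -1, -1]
  ]

-- snakesAndLadders exampleBoard should evaluate to 4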
Conclusion
If you take our Solve.hs course, Module 3 is your go-to for learning about graph algorithms! You’ll implement BFS from scratch in Haskell, and learn how to apply other helpers from Algorithm.Search. In next week’s article, we’ll do one more graph problem that goes beyond the basic ideas of DFS and BFS.
Starting out with Graph Algorithms: Basic DFS
For a few weeks now, we’ve been tackling problems related to data structures, with a sprinkling of algorithmic ideas in there. Last week, we covered sets and heaps. Prior to that, we considered Matrices and the binary search algorithm.
This week, we’ll cover our first graph problem! Graph problems often build on a lot of fundamental layers. You need to understand the algorithm itself. Then you need to use the right data structures to apply it. And you’ll also still need the core problem solving patterns at your disposal. These 3 areas correspond to the first 3 modules of Solve.hs, our Haskell problem solving course! Check out that course to level up your Haskell skills!
The Problem
Today’s problem is called Number of Islands. We are given a 2D array as input, where every cell is either “land” or “sea” (either represented as the characters 1
and 0
, or True
and False
). We want to find the number of distinct islands in this grid. Two “land” cells are part of the same island if we can draw a path from one cell to the other that only uses other land cells and only travels up, down, left, and right (but not diagonally).
Let’s suppose we have this example:
111000
100101
100111
110000
000110
This grid has 3 islands. The island in the top left corner comprises 7 connected cells. Then there’s a small island in the bottom right with only 2 cells. Finally, we have a third island in the middle right with 5 tiles. While it is diagonally adjacent to the first island, we do not count this as a connection.
The Algorithm
This is one of the most basic questions you’ll see that requires a graph search algorithm, like Depth-First-Search (DFS) or Breadth-First-Search (BFS). The basic principle is that we will select a starting coordinate for a search. We will use one of these algorithms to find all the land cells that are part of that cell’s island. We’ll then increment a counter for having found this island.
We need to track all the cells that are part of this island. We’ll then keep iterating for new start locations to find new islands, but we have to exclude any locations that have already been explored.
While BFS is certainly possible, the solutions we’ll write here will use DFS. Our solution will consist of 3 components:
- A “neighbors” function that finds all adjacent land tiles to a given tile.
- A “visit” function that will take a starting coordinate and populate a “visited” set with all of the cells on the same island as the starting coordinate.
- A core “loop” that will consider each coordinate as a possible starting value for an island.
This ordering represents more of a “bottom up” approach to solving the problem. Going “top down” also works, and may be easier if you’re unfamiliar with graph algorithms. But as you get more practice with them, you’ll get a feel for knowing the bottom layers you need right away.
Rust Solution
To write our solution, we’ll write the components in our bottom up order. We’ll start with our neighbors
function. This will take the island grid (LeetCode supplies us with a Vec<Vec<char>>
) and the current location. We’ll represent locations as tuples, (usize, usize)
. This function will return a vector of locations.
use std::collections::HashSet;
pub fn neighbors(
grid: &Vec<Vec<char>>,
loc: &(usize,usize)) -> Vec<(usize,usize)> {
...
}
We’ll start this function by defining a few values. We want to know the length and width of the grid, as well as defining r
and c
to quickly reference the current location.
use std::collections::HashSet;
pub fn neighbors(
grid: &Vec<Vec<char>>,
loc: &(usize,usize)) -> Vec<(usize,usize)> {
let m = grid.len();
let n = grid[0].len();
let r = loc.0;
let c = loc.1;
let mut result: Vec<(usize,usize)> = Vec::new();
...
}
Now we just have to look in each of the four directions. Each direction is included as long as it is a land tile and not out of bounds. We’ll do our “visited” checks elsewhere.
pub fn neighbors(
grid: &Vec<Vec<char>>,
loc: &(usize,usize)) -> Vec<(usize,usize)> {
let m = grid.len();
let n = grid[0].len();
let r = loc.0;
let c = loc.1;
let mut result: Vec<(usize,usize)> = Vec::new();
if (r > 0 && grid[r - 1][c] == '1') {
result.push((r - 1, c));
}
if (c > 0 && grid[r][c - 1] == '1') {
result.push((r, c - 1));
}
if (r + 1 < m && grid[r + 1][c] == '1') {
result.push((r + 1, c));
}
if (c + 1 < n && grid[r][c + 1] == '1') {
result.push((r, c + 1));
}
return result;
}
Now let’s write the visit
function. Remember, this function’s purpose is to populate the visited
set starting from a certain location. We’ll use a HashSet
of tuples for the visited set.
pub fn visit(
grid: &Vec<Vec<char>>,
visited: &mut HashSet<(usize,usize)>,
loc: &(usize,usize)) {
...
}
First, we’ll check if this location is already visited and return if so. Otherwise we’ll insert it.
pub fn visit(
grid: &Vec<Vec<char>>,
visited: &mut HashSet<(usize,usize)>,
loc: &(usize,usize)) {
if (visited.contains(loc)) {
return;
}
visited.insert(*loc);
...
}
All we have to do now is find the neighbors of this location, and recursively “visit” each one of them!
pub fn visit(
grid: &Vec<Vec<char>>,
visited: &mut HashSet<(usize,usize)>,
loc: &(usize,usize)) {
if (visited.contains(loc)) {
return;
}
visited.insert(*loc);
let ns = neighbors(grid, loc);
for n in ns {
visit(grid, visited, &n);
}
}
We’re not quite done, as we have to loop through our grid to call this function on each possible start. This isn’t so bad though. We start our function by defining key terms.
pub fn num_islands(grid: Vec<Vec<char>>) -> i32 {
let m = grid.len();
let n = grid[0].len();
let mut visited: HashSet<(usize,usize)> = HashSet::new();
let mut islandCount = 0;
...
// islandCount will be our final result
return islandCount;
}
Now we’ll “loop” through each possible starting location:
pub fn num_islands(grid: Vec<Vec<char>>) -> i32 {
let m = grid.len();
let n = grid[0].len();
let mut visited: HashSet<(usize,usize)> = HashSet::new();
let mut islandCount = 0;
for row in 0..m {
for col in 0..n {
...
}
}
return islandCount;
}
The last question is, what do we do for each location? If the location is land AND it is still unvisited, we treat it as the start of a new island. This means we increase the island count and then “visit” the location. When we consider other cells on this island, they’re already visited, so we won’t increase the island count when we find them!
pub fn num_islands(grid: Vec<Vec<char>>) -> i32 {
let m = grid.len();
let n = grid[0].len();
let mut visited: HashSet<(usize,usize)> = HashSet::new();
let mut islandCount = 0;
for row in 0..m {
for col in 0..n {
let loc: (usize,usize) = (row,col);
if grid[row][col] == '1' && !(visited.contains(&loc)) {
islandCount += 1;
visit(&grid, &mut visited, &loc);
}
}
}
return islandCount;
}
Here’s our complete solution:
use std::collections::HashSet;
pub fn neighbors(
grid: &Vec<Vec<char>>,
loc: &(usize,usize)) -> Vec<(usize,usize)> {
let m = grid.len();
let n = grid[0].len();
let r = loc.0;
let c = loc.1;
let mut result: Vec<(usize,usize)> = Vec::new();
if (r > 0 && grid[r - 1][c] == '1') {
result.push((r - 1, c));
}
if (c > 0 && grid[r][c - 1] == '1') {
result.push((r, c - 1));
}
if (r + 1 < m && grid[r + 1][c] == '1') {
result.push((r + 1, c));
}
if (c + 1 < n && grid[r][c + 1] == '1') {
result.push((r, c + 1));
}
return result;
}
pub fn visit(
grid: &Vec<Vec<char>>,
visited: &mut HashSet<(usize,usize)>,
loc: &(usize,usize)) {
if (visited.contains(loc)) {
return;
}
visited.insert(*loc);
let ns = neighbors(grid, loc);
for n in ns {
visit(grid, visited, &n);
}
}
pub fn num_islands(grid: Vec<Vec<char>>) -> i32 {
let m = grid.len();
let n = grid[0].len();
let mut visited: HashSet<(usize,usize)> = HashSet::new();
let mut islandCount = 0;
for row in 0..m {
for col in 0..n {
let loc: (usize,usize) = (row,col);
if grid[row][col] == '1' && !(visited.contains(&loc)) {
islandCount += 1;
visit(&grid, &mut visited, &loc);
}
}
}
return islandCount;
}
Haskell Solution
The structure of this solution translates well to Haskell, since it’s a recursive solution at its root. We’ll just use a couple of folds to handle the other loops. Let’s outline the solution:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
neighbors :: (Int, Int) -> [(Int, Int)]
visit :: (Int, Int) -> HS.HashSet (Int, Int) -> HS.HashSet (Int, Int)
loop :: (Int, Int) -> (Int, HS.HashSet (Int, Int)) -> (Int, HS.HashSet (Int, Int))
Since we’re writing our functions “in line”, we don’t need to pass the grid around like we did in our Rust solution (though inline functions are also possible there). What you should observe immediately is that visit
and loop
have a similar structure. They both fit into the a -> b -> b
pattern we want for foldr
! We’ll use this to great effect!
But first, let’s fill in neighbors
. Each of the 4 directions requires the same two conditions we used before. We make sure it’s not out of bounds, and that the next tile is “land”. Here’s how we check the “up” direction:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
neighbors :: (Int, Int) -> [(Int, Int)]
neighbors (row, col) =
let up = if row > minRow && grid A.! (row - 1, col) then Just (row - 1, col) else Nothing
...
We return Nothing
if it is not a valid neighbor. Then we just combine the four directional options with catMaybes
to complete this helper:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
neighbors :: (Int, Int) -> [(Int, Int)]
neighbors (row, col) =
let up = if row > minRow && grid A.! (row - 1, col) then Just (row - 1, col) else Nothing
left = if col > minCol && grid A.! (row, col - 1) then Just (row, col - 1) else Nothing
down = if row < maxRow && grid A.! (row + 1, col) then Just (row + 1, col) else Nothing
right = if col < maxCol && grid A.! (row, col + 1) then Just (row, col + 1) else Nothing
in catMaybes [up, left, down, right]
Now we start the visit
function by checking if we’ve already visited the location, and add it to the set if not:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
visit :: (Int, Int) -> HS.HashSet (Int, Int) -> HS.HashSet (Int, Int)
visit coord visited = if HS.member coord visited then visited else
let visited' = HS.insert coord visited
...
Now we have to get the neighbors and “loop” through the neighbors so that we keep the visited
set updated. This is where we’ll apply our first fold. We’ll recursively fold over visit
on each of the possible neighbors, which will give us the final visited set from this process. That’s all we need for this helper!
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
visit :: (Int, Int) -> HS.HashSet (Int, Int) -> HS.HashSet (Int, Int)
visit coord visited = if HS.member coord visited then visited else
let visited' = HS.insert coord visited
ns = neighbors coord
in foldr visit visited' ns
Now our loop
function will consider only a single coordinate. We think of this as having two pieces of state. First, the number of accumulated islands (the Int
). Second, we have the visited set. So we check if the coordinate is unvisited land. If so, we increase the count, and get our “new” visited set by calling visit
on it. If not, we return the original inputs.
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = ...
where
loop :: (Int, Int) -> (Int, HS.HashSet (Int, Int)) -> (Int, HS.HashSet (Int, Int))
loop coord (count, visited) = if grid A.! coord && not (HS.member coord visited)
then (count + 1, visit coord visited)
else (count, visited)
Now for the final flourish. Our loop
also has the structure for foldr
. So we’ll loop
over all the indices of our array, which will give us the final number of islands and the visited set. Our final answer is just the fst
of these:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = fst (foldr loop (0, HS.empty) (A.indices grid))
where
loop :: (Int, Int) -> (Int, HS.HashSet (Int, Int)) -> (Int, HS.HashSet (Int, Int))
Here’s our final solution:
numberOfIslands :: A.Array (Int, Int) Bool -> Int
numberOfIslands grid = fst (foldr loop (0, HS.empty) (A.indices grid))
where
((minRow, minCol), (maxRow, maxCol)) = A.bounds grid
neighbors :: (Int, Int) -> [(Int, Int)]
neighbors (row, col) =
let up = if row > minRow && grid A.! (row - 1, col) then Just (row - 1, col) else Nothing
left = if col > minCol && grid A.! (row, col - 1) then Just (row, col - 1) else Nothing
down = if row < maxRow && grid A.! (row + 1, col) then Just (row + 1, col) else Nothing
right = if col < maxCol && grid A.! (row, col + 1) then Just (row, col + 1) else Nothing
in catMaybes [up, left, down, right]
visit :: (Int, Int) -> HS.HashSet (Int, Int) -> HS.HashSet (Int, Int)
visit coord visited = if HS.member coord visited then visited else
let visited' = HS.insert coord visited
ns = neighbors coord
in foldr visit visited' ns
loop :: (Int, Int) -> (Int, HS.HashSet (Int, Int)) -> (Int, HS.HashSet (Int, Int))
loop coord (count, visited) = if grid A.! coord && not (HS.member coord visited)
then (count + 1, visit coord visited)
else (count, visited)
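To tie this back to the example grid from the start of the article, here’s a hypothetical snippet (assuming Data.Array is imported as A, with True marking land) that builds that grid and checks for 3 islands:
exampleGrid :: A.Array (Int, Int) Bool
exampleGrid = A.listArray ((0, 0), (4, 5)) $ map (== '1') $ concat
  [ "111000"
  , "100101"
  , "100111"
  , "110000"
  , "000110"
  ]

-- numberOfIslands exampleGrid should evaluate to 3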
A Note on the Graph Algorithm
It seems like we solved this without even really applying a “graph” algorithm! We just did a loop and a recursive call and everything worked out! There are a couple elements of this problem that make it one of the easiest graph problems out there.
Normally, we have to use some kind of structure to store a search state, telling us the next nodes in our graph to search. For BFS, this is a queue. For Dijkstra or A*, it is a heap. For DFS it is normally a stack. However, we are using the call stack to act as the stack for us!
When we make a recursive call to “visit” a location, we don’t need to keep track of which node we return to after we’re done. The function returns, and the prior node is already sitting there on the call stack.
The other simplifying factor is that we don’t need to do any backtracking or state restoration. Sometimes with a DFS, you need to “undo” some of the steps you took if you don’t find your goal. But this algorithm is just a space-filling algorithm. We are just trying to populate our “visited” set, and we never take nodes out of this set once we have visited them.
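For contrast, here’s a rough sketch of what visit could look like with an explicit stack of coordinates instead of recursion (placed in the same where clause so it can reuse the neighbors helper; this is purely illustrative, not something this problem needs):
visitIterative :: (Int, Int) -> HS.HashSet (Int, Int) -> HS.HashSet (Int, Int)
visitIterative start = go [start]
  where
    go [] visited = visited
    go (coord : stack) visited
      | HS.member coord visited = go stack visited
      | otherwise =
          -- Mark this coordinate, then push its neighbors onto the explicit stack.
          go (neighbors coord ++ stack) (HS.insert coord visited)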
Conclusion
We’ve got a couple more graph problems coming up next. If you want to learn more about applying graph algorithms in Haskell (including implementing them for yourself!) check out Solve.hs, where Module 3 will teach you about algorithms including DFS, BFS, A* and beyond!
Sets & Heaps in Haskell and Rust
In the last few articles, we’ve spent some time doing manual algorithms with binary trees. But most often, you won’t have to work with binary trees yourself, as they are built in to the operations of other data structures.
Today, we’ll solve a problem that relies on using the “set” data structure as well as the “heap” data structure. To learn more details about using these structures in Haskell, sign up for our Solve.hs course, where you’ll spend an entire module on data structures!
The Problem
Today’s problem is Find k pairs with smallest sums. Our problem input consists of two sorted arrays of integers as well as a number k
. Our job is to return the k
pairs of numbers that have the smallest sum, where a “pair” consists of one number from the first array, and one number from the second array.
Here’s an example input and output:
inputNums1 = [1,10,101,102]
inputNums2 = [7,8,9,10,11]
k = 11
output =
[ (1,7), (1,8), (1,9), (1,10), (1,11)
, (10,7), (10,8), (10,9), (10,10)
, (10,11), (101,7)
]
Observe that we are returning the numbers themselves in pairs, rather than the sums, and not the indices of the numbers.
While the (implicit) indices of each pair must be unique, the numbers do not need to be unique. For example, if we have the lists [1,1] and [5, 7]
, and k
is 2, we should return [(1,5), (1,5)]
, where both of the 1’s from the first list are paired with the 5 from the second list.
The Algorithm
At the center of our algorithm is a min-heap. We want to store elements that contain pairs of numbers. But we want those elements to be ordered by their sum. We also want each element to contain the corresponding index of each number in its array.
We’ll start by making a pair of the first element from each array, and inserting that into our heap, because this pair must be the smallest. Then we’ll do a loop where we extract the minimum from the heap, and then try adding its “neighbors” into the heap. A “neighbor” of the index pair (i1, i2)
comes from incrementing one of the two indices, so either (i1 + 1, i2)
or (i1, i2 + 1)
. We continue until we have extracted k
elements (or our heap is empty).
Each time we insert a pair of numbers into the heap, we’ll insert the pair of indices into a set. This will allow us to avoid double counting any pairs.
So using the first example in the section above, here are the first few steps of our process. Inside the heap is a 3-tuple with the sum, the indices (a 2-tuple), and the values (another 2-tuple).
Step 1:
Heap: [(8, (0,0), (1,7))]
Visited: [(0,0)]
Output: []
Step 2:
Heap: [(9, (0,1), (1,8)), (17, (1,0), (10,7))]
Visited: [(0,0), (0,1), (1,0)]
Output: [(1,7)]
Step 3:
Heap: [(10, (0,2), (1,9)), (17, (1,0), (10,7)), (18, (1,1), (10,8))]
Visited: [(0,0), (0,1), (0,2), (1,0), (1,1)]
Output: [(1,7), (1,8)]
And so we would continue this process until we gathered k
outputs.
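One detail worth noting before we write any code: Haskell tuples compare lexicographically, so putting the sum in the first slot of each heap element is exactly what makes the min-heap order elements by sum. A quick illustration (a hypothetical GHCi check, assuming Data.Heap is imported as H):
-- The element with the smaller sum comes out of the heap first.
-- ghci> let h = H.fromList [(17, (1,0), (10,7)), (9, (0,1), (1,8))] :: H.MinHeap (Int, (Int, Int), (Int, Int))
-- ghci> fmap fst (H.view h)
-- Just (9,(0,1),(1,8))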
Haskell Solution
The core structure of this problem isn’t hard. We define some initial terms, such as our data structures and outputs, and then we have a single loop. We’ll express this single loop as a recursive function. It will take the number of remaining values, the visited set, the heap, and the accumulated output, and ultimately return the output. Throughout this problem we’ll use the type alias I2
to refer to a tuple of two integers.
import qualified Data.Heap as H
import qualified Data.Set as S
import qualified Data.Vector as V
type I2 = (Int, Int)
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
where
n1 = V.length arr1
n2 = V.length arr2
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f remaining visited heap acc = ...
Let’s fill in our loop. We can start with the edge cases. If the k
is 0, or if our heap is empty, we should return the results list (in reverse).
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
where
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f 0 _ _ acc = reverse acc
f remaining visited heap acc = case H.view heap of
Nothing -> reverse acc
Just ((_, (i1, i2), (v1, v2)), restHeap) -> ...
Our primary case now results from extracting the min element of the heap. We don’t actually need its sum. We just put that in the first position of the tuple so that it is the primary sorting value for the heap. Let’s define the next possible coordinate pairs from adding to i1
and i2
, as well as the new sums we get from using those indices:
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
where
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f 0 _ _ acc = reverse acc
f remaining visited heap acc = case H.view heap of
Nothing -> reverse acc
Just ((_, (i1, i2), (v1, v2)), restHeap) ->
let c1 = (i1 + 1, i2)
c2 = (i1, i2 + 1)
inc1 = arr1 V.! (i1 + 1) + v2
inc2 = v1 + arr2 V.! (i2 + 1)
in ...
Now we need to try adding these values to the remaining heap. If the index is too large, or if we’ve already visited the coordinate, we don’t add the new value, returning the old heap. Otherwise we insert
it into our heap. For the second value, we just use heap1
(from trying to add the first value) as the baseline.
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
where
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f 0 _ _ acc = reverse acc
f remaining visited heap acc = case H.view heap of
Nothing -> reverse acc
Just ((_, (i1, i2), (v1, v2)), restHeap) ->
let c1 = (i1 + 1, i2)
c2 = (i1, i2 + 1)
inc1 = arr1 V.! (i1 + 1) + v2
inc2 = v1 + arr2 V.! (i2 + 1)
heap1 = if i1 + 1 < n1 && S.notMember c1 visited
then H.insert (inc1, c1, (arr1 V.! (i1 + 1), v2)) restHeap else restHeap
heap2 = if i2 + 1 < n2 && S.notMember c2 visited
then H.insert (inc2, c2, (v1, arr2 V.! (i2 + 1))) heap1 else heap1
in ...
Now we complete our recursive loop function, by adding these new indices to the visited
set and making a recursive call. We decrement the remaining number, and append the values to our accumulated list.
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
where
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f 0 _ _ acc = reverse acc
f remaining visited heap acc = case H.view heap of
Nothing -> reverse acc
Just ((_, (i1, i2), (v1, v2)), restHeap) ->
let c1 = (i1 + 1, i2)
c2 = (i1, i2 + 1)
inc1 = arr1 V.! (i1 + 1) + v2
inc2 = v1 + arr2 V.! (i2 + 1)
heap1 = if i1 + 1 < n1 && S.notMember c1 visited
then H.insert (inc1, c1, (arr1 V.! (i1 + 1), v2)) restHeap else restHeap
heap2 = if i2 + 1 < n2 && S.notMember c2 visited
then H.insert (inc2, c2, (v1, arr2 V.! (i2 + 1))) heap1 else heap1
visited' = foldr S.insert visited ([c1, c2] :: [I2])
in f (remaining - 1) visited' heap2 ((v1,v2) : acc)
To complete the function, we define our initial heap and make the first call to our loop function:
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = f k (S.singleton (0,0)) initialHeap []
where
val1 = arr1 V.! 0
val2 = arr2 V.! 0
initialHeap = H.singleton (val1 + val2, (0,0), (val1, val2))
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f = ...
Now we’re done! Here’s our complete solution:
import qualified Data.Heap as H
import qualified Data.Set as S
import qualified Data.Vector as V
type I2 = (Int, Int)
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = f k (S.singleton (0,0)) initialHeap []
where
val1 = arr1 V.! 0
val2 = arr2 V.! 0
initialHeap = H.singleton (val1 + val2, (0,0), (val1, val2))
n1 = V.length arr1
n2 = V.length arr2
f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
f 0 _ _ acc = reverse acc
f remaining visited heap acc = case H.view heap of
Nothing -> reverse acc
Just ((_, (i1, i2), (v1, v2)), restHeap) ->
let c1 = (i1 + 1, i2)
c2 = (i1, i2 + 1)
inc1 = arr1 V.! (i1 + 1) + v2
inc2 = v1 + arr2 V.! (i2 + 1)
heap1 = if i1 + 1 < n1 && S.notMember c1 visited
then H.insert (inc1, c1, (arr1 V.! (i1 + 1), v2)) restHeap else restHeap
heap2 = if i2 + 1 < n2 && S.notMember c2 visited
then H.insert (inc2, c2, (v1, arr2 V.! (i2 + 1))) heap1 else heap1
visited' = foldr S.insert visited ([c1, c2] :: [I2])
in f (remaining - 1) visited' heap2 ((v1,v2) : acc)
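As a quick check (a hypothetical snippet, assuming Data.Vector is imported as V), running the example from the problem statement should reproduce the 11 pairs listed there:
examplePairs :: [I2]
examplePairs = findKPairs (V.fromList [1, 10, 101, 102]) (V.fromList [7, 8, 9, 10, 11]) 11
-- [(1,7),(1,8),(1,9),(1,10),(1,11),(10,7),(10,8),(10,9),(10,10),(10,11),(101,7)]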
Rust Solution
Now, on to our Rust solution. We’ll start by defining our terms. These follow the pattern laid out in our algorithm and the Haskell solution:
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
let mut heap: BinaryHeap<Reverse<(i32, (usize,usize), (i32,i32))>> = BinaryHeap::new();
let val1 = nums1[0];
let val2 = nums2[0];
heap.push(Reverse((val1 + val2, (0,0), (val1, val2))));
let mut visited: HashSet<(usize, usize)> = HashSet::new();
visited.insert((0,0));
let mut results = Vec::new();
let mut remaining = k;
let n1 = nums1.len();
let n2 = nums2.len();
...
return results;
}
The most interesting of these is the heap
. We parameterize the type of the BinaryHeap
using the same kind of tuple we had in Haskell. But in order to make it a “Min” heap, we have to wrap our values in the Reverse
type.
Now let’s define the outline of our loop. We keep going until remaining
is 0. We will also break if we can’t pop
a value from the heap.
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
...
while remaining > 0 {
if let Some(Reverse((sumNum, (i1, i2), (v1, v2)))) = heap.pop() {
...
} else {
break;
}
remaining -= 1;
}
return results;
}
Now we’ll define our new coordinates, and add tests for whether or not these can be added to the heap:
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
...
while remaining > 0 {
if let Some(Reverse((sumNum, (i1, i2), (v1, v2)))) = heap.pop() {
let c1 = (i1 + 1, i2);
let c2 = (i1, i2 + 1);
if i1 + 1 < n1 && !visited.contains(&c1) {
let inc1 = nums1[i1 + 1] + v2;
...
}
if i2 + 1 < n2 && !visited.contains(&c2) {
let inc2 = v1 + nums2[i2 + 1];
...
}
} else {
break;
}
remaining -= 1;
}
return results;
}
In each case, we add the new coordinate to visited
, and push the new element on to the heap. We also push the values onto our results
array.
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
...
while remaining > 0 {
if let Some(Reverse((sumNum, (i1, i2), (v1, v2)))) = heap.pop() {
let c1 = (i1 + 1, i2);
let c2 = (i1, i2 + 1);
if i1 + 1 < n1 && !visited.contains(&c1) {
let inc1 = nums1[i1 + 1] + v2;
visited.insert(c1);
heap.push(Reverse((inc1, c1, (nums1[i1 + 1], v2))));
}
if i2 + 1 < n2 && !visited.contains(&c2) {
let inc2 = v1 + nums2[i2 + 1];
visited.insert(c2);
heap.push(Reverse((inc2, c2, (v1, nums2[i2 + 1]))));
}
results.push(vec![v1,v2]);
} else {
break;
}
remaining -= 1;
}
return results;
}
And that’s all we need!
use std::collections::BinaryHeap;
use std::collections::HashSet;
use std::cmp::Reverse;
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
let mut heap: BinaryHeap<Reverse<(i32, (usize,usize), (i32,i32))>> = BinaryHeap::new();
let val1 = nums1[0];
let val2 = nums2[0];
heap.push(Reverse((val1 + val2, (0,0), (val1, val2))));
let mut visited: HashSet<(usize, usize)> = HashSet::new();
visited.insert((0,0));
let mut results = Vec::new();
let mut remaining = k;
let n1 = nums1.len();
let n2 = nums2.len();
while remaining > 0 {
if let Some(Reverse((sumNum, (i1, i2), (v1, v2)))) = heap.pop() {
let c1 = (i1 + 1, i2);
let c2 = (i1, i2 + 1);
if i1 + 1 < n1 && !visited.contains(&c1) {
let inc1 = nums1[i1 + 1] + v2;
visited.insert(c1);
heap.push(Reverse((inc1, c1, (nums1[i1 + 1], v2))));
}
if i2 + 1 < n2 && !visited.contains(&c2) {
let inc2 = v1 + nums2[i2 + 1];
visited.insert(c2);
heap.push(Reverse((inc2, c2, (v1, nums2[i2 + 1]))));
}
results.push(vec![v1,v2]);
} else {
break;
}
remaining -= 1;
}
return results;
}
Conclusion
Next time, we’ll start exploring some graph problems, which also rely on data structures like we used here!
I didn’t explain too much in this article about the details of using these various data structures. If you want an in-depth exploration of how data structures work in Haskell, including the “common API” that helps you use almost all of them, you should sign up for our Solve.hs course! Module 2 is completely dedicated to teaching you about data structures, and you’ll get a lot of practice working with these structures in sample problems.
Binary Tree BFS: Zigzag Order
In our last article, we explored how to perform an in-order traversal of a binary search tree. Today we’ll do one final binary tree problem to solidify our understanding of some common tree patterns, as well as the tricky syntax for dealing with a binary tree in Rust.
If you want some interesting challenge problems using Haskell data structures, you should take our Solve.hs course. In particular, you’ll learn how to write a self-balancing binary tree to use for an ordered set!
The Problem
Today we will solve Zigzag Level Order Traversal. For any binary tree, we can think about it in terms of “levels” based on the number of steps from the root. So given this tree:
      45
     /  \
   32    50
  /  \     \
 5    40    100
     /  \
    37   43
We can visually see that there are 4 levels. So a normal level order traversal would return a list of 4 lists, where each list is a single level, ordered from left to right, visually speaking:
[45]
[32, 50]
[5, 40, 100]
[37, 43]
However, with a zigzag level order traversal, every other level is reversed. So we should get the following result for the input tree:
[45]
[50, 32]
[5, 40, 100]
[43, 37]
So we can imagine that we do the first level from left to right and then zigzag back to get the second level from right to left. Then we do left to right again for the third level, and so on.
The Algorithm
For our in-order traversal, we used a kind of depth-first search (DFS), and this approach is more common for tree-based problems. However, for a level-order problem, we want more of a breadth-first search (BFS). In a BFS, we explore states in order of their distance to the root. Since all nodes in a level have the same distance to the root, this makes sense.
Our general idea is that we’ll store a list of all the nodes from the prior level. Initially, this will just contain the root node. We’ll loop through this list, and create a new list of the values from the nodes in this list. This gets appended to our final result list.
While we’re doing this loop, we’ll also compose the list for the next level. The only trick is knowing whether to add each node’s left or right child to the next-level list first. This flips each iteration, so we’ll need a boolean tracking it that flips each time.
Once we encounter a level that produces no numbers (i.e. it only contains Nil
nodes), we can stop iterating and return our list of lists.
Rust Solution
Now that we’re a bit more familiar with manipulating Rc RefCells, we’ll start with the Rust solution, framing it according to the two-loop structure in our algorithm. We’ll define stack1
, which is the iteration stack, and stack2
, where we accumulate the new nodes for the next layer. We also define our final result vector, a list of lists.
pub fn zigzag_level_order(root: Option<Rc<RefCell<TreeNode>>>) -> Vec<Vec<i32>> {
let mut result: Vec<Vec<i32>> = Vec::new();
let mut stack1: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
stack1.push(root.clone());
let mut stack2: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
let mut leftToRight = true;
...
return result;
}
Our initial loop will continue until stack1
no longer contains any elements. So our basic condition is while (!stack1.is_empty()). However, there’s another important element here.
After we accumulate the new nodes in stack2
, we want to flip the meanings of our two stacks. We want our accumulated nodes referred to by stack1
, and stack2
to be an empty list to accumulate. We accomplish this in Rust by clearing stack1
at the end of our loop, and then using std::mem::swap
to flip their meanings:
pub fn zigzag_level_order(root: Option<Rc<RefCell<TreeNode>>>) -> Vec<Vec<i32>> {
let mut result: Vec<Vec<i32>> = Vec::new();
let mut stack1: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
stack1.push(root.clone());
let mut stack2: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
let mut leftToRight = true;
while (!stack1.is_empty()) {
let mut thisLayer = Vec::new(); // Values from this level
...
leftToRight = !leftToRight;
stack1.clear();
mem::swap(&mut stack1, &mut stack2);
}
return result;
}
In C++ we could accomplish something like this using std::move, but only because we want stack2 to be left in an empty state:
stack1 = std::move(stack2);
Also, observe that we flip our boolean flag at the end of the iteration.
Now let’s get to work on the inner loop. This will actually go through stack1
, add values to thisLayer
, and accumulate the next layer of nodes for stack2
. An interesting finding is that whether we’re going left to right or vice versa, we want to loop through stack1
in reverse. This means we’re treating it like a true stack instead of a vector, first accessing the last node to be added.
A left-to-right pass will add lefts and then rights. This means the right-most node in the next layer is on “top” of the stack, at the end of the vector. A right-to-left pass will first add the right child for a node before its left. This means the left-most node of the next layer is at the end of the vector.
Let’s frame up this loop, and also add the results of this layer to our final result vector.
pub fn zigzag_level_order(root: Option<Rc<RefCell<TreeNode>>>) -> Vec<Vec<i32>> {
...
while (!stack1.is_empty()) {
let mut thisLayer = Vec::new();
for node in stack1.iter().rev() {
...
}
if (!thisLayer.is_empty()) {
result.push(thisLayer);
}
leftToRight = !leftToRight;
stack1.clear();
mem::swap(&mut stack1, &mut stack2);
}
return result;
}
Note that we do not add the values array if it is empty. We allow ourselves to accumulate None
nodes in our stack. The final layer we encounter will actually consist of all None
nodes, and we don’t want this layer to add an empty list.
Now all we need to do is populate the inner loop. We only take action if the node from stack1
is Some
instead of None
. Then we follow a few simple steps:
- Borrow the TreeNode from this RefCell.
- Push its value onto thisLayer.
- Add its children (using clone) to stack2, in the right order.
Here’s the code:
pub fn zigzag_level_order(root: Option<Rc<RefCell<TreeNode>>>) -> Vec<Vec<i32>> {
...
while (!stack1.is_empty()) {
let mut thisLayer = Vec::new();
for node in stack1.iter().rev() {
if let Some(current) = node {
let currentTreeNode = current.borrow();
thisLayer.push(currentTreeNode.val);
if leftToRight {
stack2.push(currentTreeNode.left.clone());
stack2.push(currentTreeNode.right.clone());
} else {
stack2.push(currentTreeNode.right.clone());
stack2.push(currentTreeNode.left.clone());
}
}
}
...
}
return result;
}
And now we’re done! Here’s the full solution:
use std::rc::Rc;
use std::cell::RefCell;
use std::mem;
pub fn zigzag_level_order(root: Option<Rc<RefCell<TreeNode>>>) -> Vec<Vec<i32>> {
let mut result: Vec<Vec<i32>> = Vec::new();
let mut stack1: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
stack1.push(root.clone());
let mut stack2: Vec<Option<Rc<RefCell<TreeNode>>>> = Vec::new();
let mut leftToRight = true;
while (!stack1.is_empty()) {
let mut thisLayer = Vec::new();
for node in stack1.iter().rev() {
if let Some(current) = node {
let currentTreeNode = current.borrow();
thisLayer.push(currentTreeNode.val);
if leftToRight {
stack2.push(currentTreeNode.left.clone());
stack2.push(currentTreeNode.right.clone());
} else {
stack2.push(currentTreeNode.right.clone());
stack2.push(currentTreeNode.left.clone());
}
}
}
if (!thisLayer.is_empty()) {
result.push(thisLayer);
}
leftToRight = !leftToRight;
stack1.clear();
mem::swap(&mut stack1, &mut stack2);
}
return result;
}
Haskell Solution
While our Rust solution was better described from the outside in, it’s easy to build the Haskell solution from the inside out. We have two loops, and we can start by defining the inner loop (we’ll call it the stack loop).
The goal of this loop is to take stack1
and turn it into stack2
(the next layer) and the numbers for this layer, while also tracking the direction of iteration. Both outputs are accumulated as lists, so we have inputs for them as well:
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = ...
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop isLeftToRight stack1 stack2 nums = ...
When stack1
is empty, we return our result from this loop. Because of list accumulation order, we reverse nums
when giving the result. However, we don’t reverse stack2
, because we want to iterate starting from the “top”. This seems like the opposite of what we did in Rust, because Rust uses a vector for its stack type, instead of a singly linked list!
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = ...
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop _ [] stack2 nums = (stack2, reverse nums)
stackLoop isLeftToRight (Nil : rest) stack2 numbers = stackLoop isLeftToRight rest stack2 numbers
stackLoop isLeftToRight (Node x left right : rest) stack2 nums = ...
Observe also a second edge case…for Nil
nodes in stack1
, we just recurse on the rest of the list. Now for the main case, we just define the new stack2
, which adds the child nodes in the correct order. Then we recurse while also adding x
to nums
.
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = ...
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop _ [] stack2 nums = (stack2, reverse nums)
stackLoop isLeftToRight (Nil : rest) stack2 numbers = stackLoop isLeftToRight rest stack2 numbers
stackLoop isLeftToRight (Node x left right : rest) stack2 nums =
let stack2' = if isLeftToRight then right : left : stack2 else left : right : stack2
in stackLoop isLeftToRight rest stack2' (x : nums)
...
Now we’ll define the outer loop, which we’ll call the layerLoop
. This takes the direction flag and stack1
, plus the accumulator list for the results. It also has a simple base case to reverse the results list once stack1
is empty.
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = layerLoop True [root] []
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop = ...
layerLoop :: Bool -> [TreeNode] -> [[Int]] -> [[Int]]
layerLoop _ [] allNums = reverse allNums
layerLoop isLeftToRight stack1 allNums = ...
Now in the recursive case, we call the stackLoop
to get our new numbers and the stack for the next layer (which we now think of as our new stack1
). We then recurse, flipping the boolean flags and adding these new numbers to our results, but only if the list is not empty.
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = layerLoop True [root] []
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop = ...
layerLoop :: Bool -> [TreeNode] -> [[Int]] -> [[Int]]
layerLoop _ [] allNums = reverse allNums
layerLoop isLeftToRight stack1 allNums =
let (stack1', newNums) = stackLoop isLeftToRight stack1 [] []
in layerLoop (not isLeftToRight) stack1' (if null newNums then allNums else newNums : allNums)
The last step, as you’ve seen, is calling layerLoop
from the start with root
. We’re done! Here’s our final implementation:
zigzagOrderTraversal :: TreeNode -> [[Int]]
zigzagOrderTraversal root = layerLoop True [root] []
where
stackLoop :: Bool -> [TreeNode] -> [TreeNode] -> [Int] -> ([TreeNode], [Int])
stackLoop _ [] stack2 nums = (stack2, reverse nums)
stackLoop isLeftToRight (Nil : rest) stack2 numbers = stackLoop isLeftToRight rest stack2 numbers
stackLoop isLeftToRight (Node x left right : rest) stack2 nums =
let stack2' = if isLeftToRight then right : left : stack2 else left : right : stack2
in stackLoop isLeftToRight rest stack2' (x : nums)
layerLoop :: Bool -> [TreeNode] -> [[Int]] -> [[Int]]
layerLoop _ [] allNums = reverse allNums
layerLoop isLeftToRight stack1 allNums =
let (stack1', newNums) = stackLoop isLeftToRight stack1 [] []
in layerLoop (not isLeftToRight) stack1' (if null newNums then allNums else newNums : allNums)
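To check this against the example tree from the start of the article, here’s a hypothetical snippet (using the TreeNode type with Nil and Node constructors from this series):
exampleTree :: TreeNode
exampleTree = Node 45
  (Node 32 (Node 5 Nil Nil) (Node 40 (Node 37 Nil Nil) (Node 43 Nil Nil)))
  (Node 50 Nil (Node 100 Nil Nil))

-- zigzagOrderTraversal exampleTree should evaluate to:
-- [[45],[50,32],[5,40,100],[43,37]]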
Conclusion
That’s all we’ll do for binary trees right now. In the coming articles we’ll continue to explore more data structures as well as some common algorithms. If you want to learn more about data structures and algorithms in Haskell, check out our course Solve.hs. Modules 2 & 3 are filled with this sort of content, including lots of practice problems.
In-Order Traversal in Haskell and Rust
Last time around, we started exploring binary trees. We began with a simple problem (inverting a tree), but encountered some of the difficulties implementing a recursive data structure in Rust.
Today we’ll do a slightly harder problem (LeetCode rates it as “Medium” instead of “Easy”). This problem also works specifically with a binary search tree instead of a simple binary tree. With a search tree, we have the property that the “values” on each node are orderable, and all the values to the “left” of any given node are no greater than that node’s value, and the values to the “right” are not smaller.
Binary search trees are the heart of any ordered Set
type. In our problem solving course Solve.hs, you’ll get the chance to build a self-balancing binary search tree from scratch, which involves some really cool algorithmic tricks!
The Problem
Though it’s harder than our previous problem, today’s problem is still straightforward. We are taking an ordered binary search tree and finding the k-th smallest element in that tree, where k
is the second input to our function.
So suppose our tree looks like this:
      45
     /  \
   32    50
  /  \     \
 5    40    100
     /  \
    37   43
Our input k
is 1-indexed. So if we get 1
as our input, we should return 5
, the smallest element in the tree. If we receive 4
, we should return 40
, the 4th smallest element after 5, 32 and 37. If we get 8
, we’ll return 100
, the largest element in the tree.
The Algorithm
Binary search trees are designed to give logarithmic time access, insertions, and deletions for elements. If our tree was annotated, so that each node stored the number of children it had, we’d be able to solve this problem in logarithmic time as well.
However, we want to assume a minimal tree design, where each node only holds its own value and pointers to its two children. With these constraints, our algorithm has to be linear in terms of the input k
.
We’re going to solve this with an in-order traversal. We’re just going to traverse the elements of the BST in order from smallest value to largest, and count until we’ve encountered k
elements. Then we’ll return the value at our “current” node.
An in-order traversal is conceptually simple. For a given node, we “visit” the left child, then visit the node itself, and then visit the right child. But the actual mechanics of doing this traversal can be a little tricky to think of on the spot if you haven’t practiced it before.
The main idea is that we’ll use a stack of nodes to track where we are in the tree. The stack traces a path from our “current” node back up its parents back to the root of the tree. Our algorithm looks like this:
First, we create a stack from the root, following all left child nodes until we reach a node without a left child. This node has the smallest element of the tree.
Now, we begin processing, always considering the top node in the stack, while tracking the number of elements remaining until we hit k
. If that number is down to 1
, then the value at the top node in the stack is our result.
If not, we’ll decrement the number of elements remaining, and then check the right child of this node. If the right child is Nil
, we just pop this node from the stack and process its parent. If the right child does exist, we’ll add all of its left children to our stack as well.
Here’s what the stack looks like with our example tree, if we’re looking for k=7
(which should be 50).
[5, 32, 45] -- k = 1, Initial left children of root
[32, 45] -- k = 2, popped 5, no right child
[37, 40, 45] -- k = 3, popped 32, right child was 40, which added left child 37
[40, 45] -- k = 4, popped 37, no right child
[43, 45] -- k = 5, popped 40, and 43 is right child
[45] -- k = 6, popped 43
[50] -- k = 7, popped 45 and added 50, the right child (no left children)
Since 50 is on top of the stack when k = 7, we can return 50.
Haskell Solution
Let’s code this up! We’ll start with Haskell, since Rust is, once again, somewhat tricky due to TreeNode
handling. To start, let’s remind ourselves of the recursive TreeNode
type:
data TreeNode = Nil | Node Int TreeNode TreeNode
deriving (Show, Eq)
Now when writing up the algorithm, we first want to define a helper function addLeftNodesToStack. We’ll use this at the beginning of the algorithm, and then again each time we encounter a right child.
This helper will take a TreeNode and the existing stack, and return the modified stack.
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = ...
where
addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
addLeftNodesToStack = ...
As far as recursive helpers go, this is a simple one! If our input node is Nil, we return the original stack. We want to maintain the invariant that we never include Nil values in our stack! But if we have a value node, we just add it to the stack and recurse on its left child.
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = ...
where
addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
addLeftNodesToStack Nil acc = acc
addLeftNodesToStack root@(Node _ left _) acc = addLeftNodesToStack left (root : acc)
Now it’s time to implement our algorithm for finding the k-th element. This will be a recursive function that takes the number of elements remaining, as well as the current stack. We’ll call this initially with k
and the stack we get from adding the left nodes of the root:
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
where
addLeftNodesToStack = ...
findK :: Int -> [TreeNode] -> Int
findK = ...
This function has a couple error cases. We expect a non-empty stack (our input k is constrained within the size of the tree), and we expect the top to be non-Nil. After that, we have our base case where k = 1, and we return the value at this node.
Finally, we get our recursive case. We decrement the remaining count, pop the current node, and add its right child (along with that child’s left descendants) to the rest of the stack.
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
where
addLeftNodesToStack = ...
findK :: Int -> [TreeNode] -> Int
findK k [] = error $ "Found empty list expecting k: " ++ show k
findK _ (Nil : _) = error "Added Nil to stack!"
findK 1 (Node x _ _ : _) = x
findK k (Node _ _ right : rest) = findK (k - 1) (addLeftNodesToStack right rest)
This completes our solution!
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
where
addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
addLeftNodesToStack Nil acc = acc
addLeftNodesToStack root@(Node _ left _) acc = addLeftNodesToStack left (root : acc)
findK :: Int -> [TreeNode] -> Int
findK k [] = error $ "Found empty list expecting k: " ++ show k
findK _ (Nil : _) = error "Added Nil to stack!"
findK 1 (Node x _ _ : _) = x
findK k (Node _ _ right : rest) = findK (k - 1) (addLeftNodesToStack right rest)
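As a quick sanity check, here’s the example tree from the problem statement encoded with our TreeNode type, along with the queries we walked through above (the expected values come straight from the earlier discussion):
-- The example tree, for testing kthSmallest by hand.
exampleTree :: TreeNode
exampleTree = Node 45
  (Node 32 (Node 5 Nil Nil) (Node 40 (Node 37 Nil Nil) (Node 43 Nil Nil)))
  (Node 50 Nil (Node 100 Nil Nil))

-- kthSmallest exampleTree 1 == 5
-- kthSmallest exampleTree 4 == 40
-- kthSmallest exampleTree 7 == 50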
Rust Solution
In our Rust solution, we’re once again working with this TreeNode
type, including the 3 wrapper layers:
#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
pub val: i32,
pub left: Option<Rc<RefCell<TreeNode>>>,
pub right: Option<Rc<RefCell<TreeNode>>>,
}
Our first step will be to implement the helper function to add the “left” nodes. This function will take a “root” node as well as a mutable reference to the stack so we can add nodes to it.
fn add_left_nodes_to_stack(
node: Option<Rc<RefCell<TreeNode>>>,
stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
...
}
You’ll notice that stack does not actually use the Option wrapper, only Rc and RefCell. Remember that in our Haskell solution we wanted to enforce the invariant that we never add Nil nodes to the stack. This Rust solution enforces that constraint at compile time, since the stack’s element type has no way to represent a null node.
To implement this function, we’ll use the same trick we did when inverting trees: pattern match on node and detect whether it is Some or None. If it is None, we don’t have to do anything.
fn add_left_nodes_to_stack(
node: Option<Rc<RefCell<TreeNode>>>,
stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
if let Some(current) = node {
...
}
}
Since current is now unwrapped from Option, we can push it to the stack. As in our previous problem though, we have to clone it first! We need a clone of the reference (as wrapped by Rc), because the stack will now have to own this reference.
fn add_left_nodes_to_stack(
node: Option<Rc<RefCell<TreeNode>>>,
stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
if let Some(current) = node {
stack.push(current.clone());
...
}
}
Now we’ll recurse on the left child of current. In order to unwrap the TreeNode from Rc/RefCell, we have to use borrow. Then we can grab the left value. But again, we have to clone it before we make the recursive call. Here’s the final implementation of this helper:
fn add_left_nodes_to_stack(
node: Option<Rc<RefCell<TreeNode>>>,
stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
if let Some(current) = node {
stack.push(current.clone());
add_left_nodes_to_stack(current.borrow().left.clone(), stack);
}
}
We could have implemented the helper with a while loop instead of recursion. This would actually have used less memory in Rust! We would have to make some changes though, like introducing a mutable binding that starts at the root and walks down the left children.
Now we can move on to the core function. We’ll start this by defining key terms like our stack and the number of “remaining” values (initially k). We’ll also call our helper to get the initial stack.
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
let mut stack = Vec::new();
let mut remaining = k;
add_left_nodes_to_stack(root, &mut stack);
...
}
Now we want to pop the top element from the stack, pattern matching on the Some case. If there are no more values, we’ll actually panic, because the problem constraints should mean that our stack is never empty. Unlike our helper, we actually will use a while loop here instead of more recursion:
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
let mut stack = Vec::new();
let mut remaining = k;
add_left_nodes_to_stack(root, &mut stack);
while let Some(current) = stack.pop() {
...
}
panic!("k is larger than number of nodes");
}
Now the inside of the loop is simple, following what we’ve done in Haskell. If our remainder is 1, then we have found the correct node. We borrow the node from the RefCell and return its value. Otherwise we decrement the count and use our helper on the “right” child of the node we just popped. As usual, the RefCell wrapper means we need to borrow to get the right value from the TreeNode, and then we clone this child as we pass it to the helper.
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
let mut stack = Vec::new();
let mut remaining = k;
add_left_nodes_to_stack(root, &mut stack);
while let Some(current) = stack.pop() {
if remaining == 1 {
return current.borrow().val;
}
remaining -= 1;
add_left_nodes_to_stack(current.borrow().right.clone(), &mut stack);
}
panic!("k is larger than number of nodes");
}
And that’s it! Here’s the complete Rust solution:
fn add_left_nodes_to_stack(
node: Option<Rc<RefCell<TreeNode>>>,
stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
if let Some(current) = node {
stack.push(current.clone());
add_left_nodes_to_stack(current.borrow().left.clone(), stack);
}
}
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
let mut stack = Vec::new();
let mut remaining = k;
add_left_nodes_to_stack(root, &mut stack);
while let Some(current) = stack.pop() {
if remaining == 1 {
return current.borrow().val;
}
remaining -= 1;
add_left_nodes_to_stack(current.borrow().right.clone(), &mut stack);
}
panic!("k is larger than number of nodes");
}
Conclusion
In-order traversal is a great pattern to commit to memory, as many different tree problems will require you to apply it. Hopefully the details with Rust’s RefCells
are getting more familiar. Next week we’ll do one more problem with binary trees.
If you want to do some deep work with Haskell and binary trees, take a look at our Solve.hs course, where you’ll learn about many different data structures in Haskell, and get the chance to write a balanced binary search tree from scratch!
An Easy Problem Made Hard: Rust & Binary Trees
In last week’s article, we completed our look at Matrix-based problems. Today, we’re going to start considering another data structure: binary trees.
Binary trees are an extremely important structure in programming. Most notably, they are the underlying structure for ordered sets, which allow logarithmic time lookups and insertions. A “tree” is represented by “nodes”, where a node can be “null”, or else hold a value. If it holds a value, it then has a “left” child and a “right” child.
If you take our Solve.hs course, you’ll actually learn to implement an auto-balancing ordered tree set from scratch!
But for these next few articles, we’re going to explore some simple problems that involve binary trees. Today we’ll start with a problem that is very simple (rated as “Easy” by LeetCode), but still helps us grasp the core problem solving techniques behind binary trees. We’ll also encounter some interesting curveballs that Rust can throw at us when it comes to building more complex data structures.
The Problem
Our problem today is Invert Binary Tree. Given a binary tree, we want to return a new tree that is the mirror image of the input tree. For example, if we get this tree as an input:
       45
      /  \
    32    50
   /  \     \
  5    40    100
      /  \
    37    43
We should output a tree that looks like this:
       45
      /  \
    50    32
   /     /  \
 100    40   5
       /  \
     43    37
We see that 45 remains the root element, but instead of having 32 on the left and 50 on the right, these two elements are reversed on the next level. Then on the 3rd level, 40 and 5 remain children of 32, but they are also reversed from their prior orientations! This pattern continues all the way down on both sides.
The Algorithm
Binary trees (and in fact, all tree structures) lend themselves very well to recursive algorithms. We’ll use a very simple recursive algorithm here.
If the input node to our function is a “null” node, we will simply return “null” as the output. If the node has a value and children, then we’ll keep that value as the value for our output. However, we will recursively invert both of the child nodes.
Then we’ll take the inverted “left” child node and install it as the “right” child of our result. The inverted “right” child of the original input becomes the “left” child of the resulting node.
Haskell Solution
Haskell is a natural fit for this problem, since it relies so heavily on recursion. We start by defining a recursive TreeNode
type. The canonical way to do this is with a Nil
constructor as well as a recursive “value” constructor that actually holds the node’s value and refers to the left and right child. For this problem, we’ll just assume our tree holds Int
values, so we won’t parameterize it.
data TreeNode = Nil | Node Int TreeNode TreeNode
deriving (Show, Eq)
Now solving our problem is easy! We pattern match on the input TreeNode. For our first case, we just return Nil for Nil.
invertTree :: TreeNode -> TreeNode
invertTree Nil = Nil
invertTree (Node x left right) = ...
For our second case, we use the same Int
value for the value of our result. Then we recursively call invertTree
on the right
child, but put this in the place of the left child for our new result node. Likewise, we recursively invert the left
child of our original and use this result for the right of our result.
invertTree :: TreeNode -> TreeNode
invertTree Nil = Nil
invertTree (Node x left right) = Node x (invertTree right) (invertTree left)
Very easy!
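For a quick check, here’s a small tree (a piece of our earlier example) and the result we’d expect from inverting it:
-- A small tree and its expected inversion.
smallTree :: TreeNode
smallTree = Node 45 (Node 32 Nil Nil) (Node 50 Nil (Node 100 Nil Nil))

-- invertTree smallTree
--   == Node 45 (Node 50 (Node 100 Nil Nil) Nil) (Node 32 Nil Nil)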
C++ Solution
In a non-functional language, it is still quite possible to solve this problem without recursion, but this is an occasion where we get very nice, clean code with recursion. As a rare treat, we’ll actually start with a C++ solution instead of jumping to Rust right away.
We would start by defining our TreeNode
with a struct
. Instead of relying on a separate Nil
constructor, we use raw pointers for all our tree nodes. This means they can all potentially be nullptr
.
struct TreeNode {
int val;
TreeNode* left;
TreeNode* right;
TreeNode(int v, TreeNode* l, TreeNode* r) : val(v), left(l), right(r) {};
};
TreeNode* invertTree(TreeNode* root) {
...
}
And our solution looks almost as easy as the Haskell solution:
TreeNode* invertTree(TreeNode* root) {
if (root == nullptr) {
return nullptr;
}
return new TreeNode(root->val, invertTree(root->right), invertTree(root->left));
}
Rust Solution
In Rust, it’s not quite as easy to work with recursive structures because of Rust’s memory system. In C++, we used raw pointers, which are fast but can cause significant problems if you aren’t careful (e.g. dereferencing null pointers, or memory leaks). Haskell uses garbage-collected memory, which adds runtime overhead but allows us to write simple code that won’t blow up in the weird ways C++ can.
Rust’s Memory System
Rust seeks to be fast like C++, while making it hard to do high-risk things like de-referencing a potentially null pointer, or leaking memory. It does this using the concept of “ownership”, and it’s a tricky concept to understand at first.
The ownership model makes it a bit harder for us to write a basic recursive data structure. To write a basic binary tree, you’d have to answer questions like:
- Who “owns” the child nodes?
- Can I write a function that accesses the child nodes without taking ownership of them? What if I have to modify them?
- Can I copy a reference to a child node without copying the entire sub-structure?
- Can I create a “new” tree that references part of another tree without copying?
Writing a TreeNode
Here’s the TreeNode
struct provided by LeetCode for solving this problem. We can see that references to the nodes themselves are held within 3(!) wrapper types:
#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
pub val: i32,
pub left: Option<Rc<RefCell<TreeNode>>>,
pub right: Option<Rc<RefCell<TreeNode>>>,
}
impl TreeNode {
#[inline]
pub fn new(val: i32) -> Self {
TreeNode {
val,
left: None,
right: None
}
}
}
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
}
From inside to outside, here’s what the three wrappers mean:
- RefCell is a mutable, shareable container for data.
- Rc is a reference counting container. It automatically tracks how many references there are to the RefCell. The cell is de-allocated once this count is 0.
- Option is Rust’s equivalent of Maybe. This lets us use None for an empty tree.
Rust normally only permits a single mutable reference, or multiple immutable references. RefCell provides a way to mutate data through shared references, with the borrow rules checked at runtime instead of compile time. Let’s see how we can use these to write our invert_tree function.
Solving the Problem
We start by “cloning” the root input reference. Normally, “clone” means a deep copy, but in our case, this doesn’t actually copy the entire tree! Because it is wrapped in Rc, we’re just getting a new reference to the data in the RefCell. We conditionally check if this is a Some wrapper. If it is None, we just return the root.
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
if let Some(node) = root.clone() {
...
}
return root;
}
If we didn’t “clone” root, the compiler would complain that we are “moving” the value in the condition, which would invalidate the prior reference to root.
Next, we use borrow_mut to get a mutable reference to the TreeNode inside the RefCell. This node_ref finally gives us something of type TreeNode so that we can work with the individual fields.
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
if let Some(node) = root.clone() {
let mut node_ref = node.borrow_mut();
...
}
return root;
}
Now for node_ref, both left and right have the full wrapper type Option<Rc<RefCell<TreeNode>>>. We want to recursively call invert_tree on these. Once again though, we have to call clone before passing these to the recursive function.
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
if let Some(node) = root.clone() {
let mut node_ref = node.borrow_mut();
// Recursively invert left and right subtrees
let left = invert_tree(node_ref.left.clone());
let right = invert_tree(node_ref.right.clone());
...
}
return root;
}
Now because we have a mutable reference in node_ref, we can install these new results as its left and right subtrees!
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
if let Some(node) = root.clone() {
let mut node_ref = node.borrow_mut();
// Recursively invert left and right subtrees
let left = invert_tree(node_ref.left.clone());
let right = invert_tree(node_ref.right.clone());
// Swap them
node_ref.left = right;
node_ref.right = left;
}
return root;
}
And now we’re done! We don’t need a separate return statement inside the if. We have modified node_ref, which is still a reference to the same data that root holds. So returning root returns our modified tree.
Conclusion
Even though this was a simple problem with a basic recursive algorithm, we saw how Rust presented some interesting difficulties in applying this algorithm. Languages all make different tradeoffs, so every language has some example where it is difficult to write code that is simple in other languages. For Rust, this is recursive data structures. For Haskell though, it’s things like mutable arrays.
If you want to get some serious practice with binary trees, you should sign up for our problem solving course, Solve.hs. In Module 2, you’ll actually get to implement a balanced tree set from scratch, which is a very interesting and challenging problem that will stretch your knowledge!
Spiral Matrix: Another Matrix Layer Problem
In last week’s article, we learned how to rotate a 2D Matrix in place using Haskell’s mutable array mechanics. This taught us how to think about a Matrix in terms of layers, starting from the outside and moving in towards the center.
Today, we’ll study one more 2D Matrix problem that uses this layer-by-layer paradigm. For more practice dealing with multi-dimensional arrays, check out our Solve.hs course! In Module 2, you’ll study all kinds of different data structures in Haskell, including 2D Matrices (both mutable and immutable).
The Problem
Today’s problem is Spiral Matrix. In this problem, we receive a 2D Matrix, and we would like to return the elements of that matrix in a 1D list in “spiral order”. This ordering consists of starting from the top left and going right. When we hit the top right corner, we move down to the bottom. Then we come back across the bottom row to the left, and then back up toward the top left. Then we continue this process on inner layers.
So, for example, let’s suppose we have this 4x4 matrix:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
This should return the following list:
[1,2,3,4,8,12,16,15,14,13,9,5,6,7,11,10]
At first glance, it seems like a lot of our layer-by-layer mechanics from last week will work again. All the numbers in the “first” layer come first, followed by the “second” layer, and so on. The trick though is that for this problem, we have to handle non-square matrices. So we can also have this matrix:
1 2 3 4
5 6 7 8
9 10 11 12
This should yield the list [1,2,3,4,8,12,11,10,9,5,6,7]. This isn’t a huge challenge, but we need a slightly different approach.
The Algorithm
We still want to generally move through the Matrix using a layer-by-layer approach. But instead of tracking the 4 corner points, we’ll just keep track of 4 “barriers”, imaginary lines dictating the “end” of each dimension (up/down/left/right) for us to scan. These barriers will be inclusive, meaning that they refer to the last valid row or column in that direction. We would call these “min row”, “min column”, “max row” and “max column”.
Now the general process for going through a layer will consist of 4 steps. Each step starts in a corner location and proceeds in one direction until the next corner is reached. Then, we can start again with the next layer.
The trick is the end condition. Because we can have rectangular matrices, the final layer can have a shape like 1 x n or n x 1, and this is a problem, because we wouldn’t need 4 steps. Even a square n x n matrix with odd n would have a 1x1 as its final layer, and this is also a problem since it is unclear which “corner” this coordinate belongs to.
Thus we have to handle these edge cases. However, they are easy to both detect and resolve. We know we are in such a case when “min row” and “max row” are equal, or if “min column” and “max column” are equal. Then to resolve the case, we just do one pass instead of 4, including both endpoints.
Rust Solution
For our Rust solution, let’s start by defining important terms, like we always do. For our terms, we’ll mainly be dealing with these 4 “barrier” values, the min and max for the current row and column. These are inclusive, so they are initially 0 and (length - 1). We also make a new vector to hold our result values.
pub fn spiral_order(matrix: Vec<Vec<i32>>) -> Vec<i32> {
let mut result: Vec<i32> = Vec::new();
let mut minR: usize = 0;
let mut maxR: usize = matrix.len() - 1;
let mut minC: usize = 0;
let mut maxC: usize = matrix[0].len() - 1;
...
}
Now we want to write a while
loop where each iteration processes a single layer. We’ll know we are out of layers if either “minimum” exceeds its corresponding “maximum”. Then we can start penciling in the different cases and phases of the loop. The edge cases occur when a minimum is exactly equal to its maximum. And for the normal case, we’ll do our 4-directional scanning.
pub fn spiral_order(matrix: Vec<Vec<i32>>) -> Vec<i32> {
let mut result: Vec<i32> = Vec::new();
let mut minR: usize = 0;
let mut maxR: usize = matrix.len() - 1;
let mut minC: usize = 0;
let mut maxC: usize = matrix[0].len() - 1;
while (minR <= maxR && minC <= maxC) {
// Edge cases: single row or single column layers
if (minR == maxR) {
...
break;
} else if (minC == maxC) {
...
break;
}
// Scan TL->TR
...
// Scan TR->BR
...
// Scan BR->BL
...
// Scan BL->TL
...
minR += 1;
minC += 1;
maxR -= 1;
maxC -= 1;
}
return result;
}
Our “loop update” step comes at the end, when we increase both minimums, and decrease both maximums. This shows we are shrinking to the next layer.
Now we just have to fill in each case. All of these are scans through some portion of the matrix. The only trick is getting the ranges correct for each scan.
We’ll start with the edge cases. For a single row or column scan, we just need one loop. This loop should be inclusive across its dimension. Rust has a similar range syntax to Haskell, but it is less flexible. We can make a range inclusive by using =
before the end element.
pub fn spiral_order(matrix: Vec<Vec<i32>>) -> Vec<i32> {
...
while (minR <= maxR && minC <= maxC) {
// Edge cases: single row or single column layers
if (minR == maxR) {
for i in minC..=maxC {
result.push(matrix[minR][i]);
}
break;
} else if (minC == maxC) {
for i in minR..=maxR {
result.push(matrix[i][minC]);
}
break;
}
...
}
return result;
}
Now let’s fill in the other cases. Again, getting the right ranges is the most important factor. We also have to make sure we don’t mix up our dimensions or directions! We go right along minR, down along maxC, left along maxR, and then up along minC.
To represent a decreasing range, we have to make the corresponding incrementing range and then use .rev() to reverse it. This is a little inconvenient, giving us ranges that don’t look as nice, like for i in ((minC+1)..=maxC).rev(), because we want the decrementing range to include maxC but exclude minC.
pub fn spiral_order(matrix: Vec<Vec<i32>>) -> Vec<i32> {
...
while (minR <= maxR && minC <= maxC) {
...
// Scan TL->TR
for i in minC..maxC {
result.push(matrix[minR][i]);
}
// Scan TR->BR
for i in minR..maxR {
result.push(matrix[i][maxC]);
}
// Scan BR->BL
for i in ((minC+1)..=maxC).rev() {
result.push(matrix[maxR][i]);
}
// Scan BL->TL
for i in ((minR+1)..=maxR).rev() {
result.push(matrix[i][minC]);
}
minR += 1;
minC += 1;
maxR -= 1;
maxC -= 1;
}
return result;
}
But once these cases are filled in, we’re done! Here’s the full solution:
pub fn spiral_order(matrix: Vec<Vec<i32>>) -> Vec<i32> {
let mut result: Vec<i32> = Vec::new();
let mut minR: usize = 0;
let mut maxR: usize = matrix.len() - 1;
let mut minC: usize = 0;
let mut maxC: usize = matrix[0].len() - 1;
while (minR <= maxR && minC <= maxC) {
// Edge cases: single row or single column layers
if (minR == maxR) {
for i in minC..=maxC {
result.push(matrix[minR][i]);
}
break;
} else if (minC == maxC) {
for i in minR..=maxR {
result.push(matrix[i][minC]);
}
break;
}
// Scan TL->TR
for i in minC..maxC {
result.push(matrix[minR][i]);
}
// Scan TR->BR
for i in minR..maxR {
result.push(matrix[i][maxC]);
}
// Scan BR->BL
for i in ((minC+1)..=maxC).rev() {
result.push(matrix[maxR][i]);
}
// Scan BL->TL
for i in ((minR+1)..=maxR).rev() {
result.push(matrix[i][minC]);
}
minR += 1;
minC += 1;
maxR -= 1;
maxC -= 1;
}
return result;
}
Haskell Solution
Now let’s write our Haskell solution. We don’t need any fancy mutation tricks here. Our function will just take a 2D array, and return a list of numbers.
spiralMatrix :: A.Array (Int, Int) Int -> [Int]
spiralMatrix arr = ...
where
((minR', minC'), (maxR', maxC')) = A.bounds arr
Since we used a while loop in our Rust solution, it makes sense that we’ll want to use a raw recursive function that we’ll just call f. Our loop state was the 4 “barrier” values, one for each direction. We’ll also use an accumulator value for our result. Since our barriers are inclusive, we can simply use the bounds of our array for the initial values.
spiralMatrix :: A.Array (Int, Int) Int -> [Int]
spiralMatrix arr = f minR' minC' maxR' maxC' []
where
((minR', minC'), (maxR', maxC')) = A.bounds arr
f :: Int -> Int -> Int -> Int -> [Int] -> [Int]
f = undefined
This recursive function has 3 base cases. First, we have the “loop condition” we used in our Rust solution. If a min dimension value exceeds the max, we are done, and should return our accumulated result list.
The other two cases are our edge cases of having a single row or a single column for our final layer. In all these cases, we want to reverse the accumulated list. This means that when we put together our ranges, we want to be careful that they are in reverse order! So the edge cases should start at their max value and decrease to the min value (inclusive).
spiralMatrix :: A.Array (Int, Int) Int -> [Int]
spiralMatrix arr = f minR' minC' maxR' maxC' []
where
((minR', minC'), (maxR', maxC')) = A.bounds arr
f :: Int -> Int -> Int -> Int -> [Int] -> [Int]
f minR minC maxR maxC acc
| minR > maxR || minC > maxC = reverse acc
| minR == maxR = reverse $ [arr A.! (minR, c) | c <- [maxC,maxC - 1..minC]] <> acc
| minC == maxC = reverse $ [arr A.! (r, minC) | r <- [maxR,maxR - 1..minR]] <> acc
| otherwise = ...
Now to fill in the otherwise
case, we can do our 4 steps: going right from the top left, then going down from the top right, going left from the bottom right, and going up from the bottom left.
Like the edge cases, we make list comprehensions with ranges to pull the new numbers out of our input matrix. And again, we have to make sure we accumulate them in reverse order. Then we append all of them to the existing accumulation.
spiralMatrix :: A.Array (Int, Int) Int -> [Int]
spiralMatrix arr = f minR' minC' maxR' maxC' []
where
((minR', minC'), (maxR', maxC')) = A.bounds arr
f :: Int -> Int -> Int -> Int -> [Int] -> [Int]
f minR minC maxR maxC acc
...
| otherwise =
let goRights = [arr A.! (minR, c) | c <- [maxC - 1, maxC - 2..minC]]
goDowns = [arr A.! (r, maxC) | r <- [maxR - 1, maxR - 2..minR]]
goLefts = [arr A.! (maxR, c) | c <- [minC + 1..maxC]]
goUps = [arr A.! (r, minC) | r <- [minR+1..maxR]]
acc' = goUps <> goLefts <> goDowns <> goRights <> acc
in f (minR + 1) (minC + 1) (maxR - 1) (maxC - 1) acc'
We conclude by making our recursive call with the updated result list, and shifting the barriers to get to the next layer.
Here’s the full implementation:
spiralMatrix :: A.Array (Int, Int) Int -> [Int]
spiralMatrix arr = f minR' minC' maxR' maxC' []
where
((minR', minC'), (maxR', maxC')) = A.bounds arr
f :: Int -> Int -> Int -> Int -> [Int] -> [Int]
f minR minC maxR maxC acc
| minR > maxR || minC > maxC = reverse acc
| minR == maxR = reverse $ [arr A.! (minR, c) | c <- [maxC,maxC - 1..minC]] <> acc
| minC == maxC = reverse $ [arr A.! (r, minC) | r <- [maxR,maxR - 1..minR]] <> acc
| otherwise =
let goRights = [arr A.! (minR, c) | c <- [maxC - 1, maxC - 2..minC]]
goDowns = [arr A.! (r, maxC) | r <- [maxR - 1, maxR - 2..minR]]
goLefts = [arr A.! (maxR, c) | c <- [minC + 1..maxC]]
goUps = [arr A.! (r, minC) | r <- [minR+1..maxR]]
acc' = goUps <> goLefts <> goDowns <> goRights <> acc
in f (minR + 1) (minC + 1) (maxR - 1) (maxC - 1) acc'
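Here’s a quick usage sketch with the 3x4 example from earlier, assuming Data.Array is imported qualified as A (as in our other array solutions):
-- Build the 3x4 example matrix and spiral through it.
example :: [Int]
example = spiralMatrix (A.listArray ((0, 0), (2, 3)) [1..12])
-- Expected: [1,2,3,4,8,12,11,10,9,5,6,7]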
Conclusion
This is the last matrix-based problem we’ll study for now. Next time we’ll start considering some tree-based problems. If you sign up for our Solve.hs course, you’ll learn about both of these kinds of data structures in Module 2. You’ll implement a tree set from scratch, and you’ll get lots of practice working with these and many other structures. So enroll today!
Image Rotation: Mutable Arrays in Haskell
In last week’s article, we took our first step into working with multi-dimensional arrays. Today, we’ll be working with another Matrix problem that involves in-place mutation. The Haskell solution uses the MArray interface, which takes us out of our usual immutable style of code.
The MArray interface is a little tricky to work with. If you want a full overview of the API, you should sign up for our Solve.hs course, where we cover mutable arrays in module 2!
The Problem
Today’s problem is Rotate Image. We’re going to take a 2D Matrix of integer values as our input and rotate the matrix 90 degrees clockwise. We must accomplish this in place, modifying the input value without allocating a new Matrix. The input matrix is always “square” (n x n).
Here are a few examples to illustrate the idea. We can start with a 2x2 matrix:
1 2 | 3 1
3 4 | 4 2
The 4x4 rotation makes it more clear that we’re not just moving numbers one space over. Each corner element will go to a new corner. You can also see how the inside of the matrix is also rotating:
1 2 3 4 | 13 9 5 1
5 6 7 8 | 14 10 6 2
9 10 11 12 | 15 11 7 3
13 14 15 16 | 16 12 8 4
The 3x3 version shows how, with an odd number of rows and columns, the innermost number stands still.
1 2 3 | 7 4 1
4 5 6 | 8 5 2
7 8 9 | 9 6 3
The Algorithm
While this problem might be a little intimidating at first, we just have to break it into sufficiently small and repeatable pieces. The core step is that we swap four numbers into each other’s positions. It’s easy to see, for example, that the four corners always trade places with one another (1, 4, 13, 16 in the 4x4 example).
What’s important is seeing the other sets of 4. We move clockwise to get the next 4 values:
- The value to the right of the top left corner
- The value below the top right corner
- The value to the left of the bottom right corner
- The value above the bottom left corner.
So in the 4x4 example, these would be 2, 8, 15, 9. Then another group is 3, 12, 14, 5.
Those 3 groups are all the rotations we need for the “outer layer”. Then we move to the next layer, where we have a single group of 4: 6, 7, 10, 11.
This should tell us that we have a 3-step process:
- Loop through each layer of the matrix
- Identify all groups of 4 in this layer
- Rotate each group of 4
It helps to put a count on the size of each of these loops. For an n x n matrix, the number of layers to rotate is n / 2, rounded down, because the innermost layer needs no rotation in an odd-sized matrix.
Then for a layer spanning from column c1 to c2, the number of groups in that layer is just c2 - c1. So for the first layer in a 4x4, we span columns 0 to 3, and there are 3 groups of 4. In the inner layer, we span columns 1 to 2, so there is only 1 group of 4.
Rust Solution
As is typical, we’ll see more of a loop structure in our Rust code, and a recursive version of this solution in Haskell. We’ll also start by defining various terms we’ll use. There are multiple ways to approach the details of this problem, but we’ll take an approach that maximizes the clarity of our inner loops.
We’ll define each “layer” using the four corner coordinates of that layer. So for an n x n matrix, these are (0,0), (0, n - 1), (n - 1, n - 1), (n - 1, 0)
. After we finish looping through a layer, we can simply increment/decrement each of these values as appropriate to get the corner coordinates of the next layer ((1,1), (1, n - 2)
, etc.).
So let’s start our solution by defining the 8 mutable values for these 4 corners. Each corner (top left, top right, bottom right, bottom left) has a row R and a column C value.
pub fn rotate(matrix: &mut Vec<Vec<i32>>) {
let n = matrix.len();
let numLayers = n / 2;
let mut topLeftR = 0;
let mut topLeftC = 0;
let mut topRightR = 0;
let mut topRightC = n - 1;
let mut bottomRightR = n - 1;
let mut bottomRightC = n - 1;
let mut bottomLeftR = n - 1;
let mut bottomLeftC = 0;
...
}
It would be possible to solve the problem without these values, determining coordinates using the layer number. But I’ve found this to be somewhat more error prone, since we’re constantly adding and subtracting from different coordinates in different combinations. We get the number of layers from n / 2.
Now let’s frame the outer loop. We conclude the loop by modifying each coordinate point. Then at the beginning of the loop, we can determine the number of “groups” for the layer by taking the difference between the left and right column coordinates.
pub fn rotate(matrix: &mut Vec<Vec<i32>>) {
...
for i in 0..numLayers {
let numGroups = topRightC - topLeftC;
for j in 0..numGroups {
...
}
topLeftR += 1;
topLeftC += 1;
topRightR += 1;
topRightC -= 1;
bottomRightR -= 1;
bottomRightC -= 1;
bottomLeftR -= 1;
bottomLeftC += 1;
}
}
Now we just need the logic for rotating a single group of 4 points. This is a 5-step process:
- Save top left value as temp
- Move bottom left to top left
- Move bottom right to bottom left
- Move top right to bottom right
- Move temp (original top left) to top right
Unlike the layer number, we’ll use the group variable j
for arithmetic here. When you’re writing this yourself, it’s important to go slowly to make sure you’re using the right corner values and adding/subtracting j
from the correct dimension.
pub fn rotate(matrix: &mut Vec<Vec<i32>>) {
...
for i in 0..numLayers {
let numGroups = topRightC - topLeftC;
for j in 0..numGroups {
let temp = matrix[topLeftR][topLeftC + j];
matrix[topLeftR][topLeftC + j] = matrix[bottomLeftR - j][bottomLeftC];
matrix[bottomLeftR - j][bottomLeftC] = matrix[bottomRightR][bottomRightC - j];
matrix[bottomRightR][bottomRightC - j] = matrix[topRightR + j][topRightC];
matrix[topRightR + j][topRightC] = temp;
}
... // (update corners)
}
}
And then we’re done! We don’t actually need to return a value since we’re just modifying the input in place. Here’s the full solution:
pub fn rotate(matrix: &mut Vec<Vec<i32>>) {
let n = matrix.len();
let numLayers = n / 2;
let mut topLeftR = 0;
let mut topLeftC = 0;
let mut topRightR = 0;
let mut topRightC = n - 1;
let mut bottomRightR = n - 1;
let mut bottomRightC = n - 1;
let mut bottomLeftR = n - 1;
let mut bottomLeftC = 0;
for i in 0..numLayers {
let numGroups = topRightC - topLeftC;
for j in 0..numGroups {
let temp = matrix[topLeftR][topLeftC + j];
matrix[topLeftR][topLeftC + j] = matrix[bottomLeftR - j][bottomLeftC];
matrix[bottomLeftR - j][bottomLeftC] = matrix[bottomRightR][bottomRightC - j];
matrix[bottomRightR][bottomRightC - j] = matrix[topRightR + j][topRightC];
matrix[topRightR + j][topRightC] = temp;
}
topLeftR += 1;
topLeftC += 1;
topRightR += 1;
topRightC -= 1;
bottomRightR -= 1;
bottomRightC -= 1;
bottomLeftR -= 1;
bottomLeftC += 1;
}
}
Haskell Solution
This is an interesting problem to solve in Haskell because Haskell is a generally immutable language. Unlike Rust, we can’t make values mutable just by putting the keyword mut in front of them.
With arrays, though, we can modify them in place using the MArray type class. We won’t go through all the details of the interface in this article (you can learn about all that in Solve.hs Module 2). But we’ll start with the type signature:
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
This tells us we are taking a mutable array, where the array type is polymorphic but tied to the monad m. For example, IOArray would work with the IO monad. We don’t return anything, because we’re modifying our input.
We still begin our function by defining terms, but now we need to use monadic actions to retrieve even the bounds of our array.
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
rotateImage arr = do
((minR, minC), (maxR, maxC)) <- getBounds arr
let n = maxR - minR + 1
let numLayers = n `quot` 2
...
Our algorithm has two loop levels. The outer loop goes through the different layers of the matrix. The inner loop goes through each group of 4 within the layer. In Haskell, both of these loops are recursive, monadic functions. Our Rust loops treat the four corner points of the layer as stateful values, so these need to be inputs to our recursive functions. In addition, each function will take the layer/group number as an input.
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
rotateImage arr = do
((minR, minC), (maxR, maxC)) <- getBounds arr
let n = maxR - minR + 1
let numLayers = n `quot` 2
...
where
rotateLayer tl@(tlR, tlC) tr@(trR, trC) br@(brR, brC) bl@(blR, blC) n = ...
rotateGroup (tlR, tlC) (trR, trC) (brR, brC) (blR, blC) j = ...
Now we just have to fill in these functions. For rotateLayer, we use the “layer number” parameter as a countdown. Once it reaches 0, we’ll be done. We just need to determine the number of groups in this layer using the column difference of left and right. Then we’ll call rotateGroup for each group.
We make the first call to rotateLayer with numLayers and the original corners, coming from our dimensions. When we recurse, we add/subtract 1 from the corner dimensions, and subtract 1 from the layer number.
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
rotateImage arr = do
((minR, minC), (maxR, maxC)) <- getBounds arr
let n = maxR - minR + 1
let numLayers = n `quot` 2
rotateLayer (minR, minC) (minR, maxC) (maxR, maxC) (maxR, minC) numLayers
where
rotateLayer _ _ _ _ 0 = return ()
rotateLayer tl@(tlR, tlC) tr@(trR, trC) br@(brR, brC) bl@(blR, blC) n = do
let numGroups = ([0..(trC - tlC - 1)] :: [Int])
forM_ numGroups (rotateGroup tl tr br bl)
rotateLayer (tlR + 1, tlC + 1) (trR + 1, trC - 1) (brR - 1, brC - 1) (blR - 1, blC + 1) (n - 1)
rotateGroup (tlR, tlC) (trR, trC) (brR, brC) (blR, blC) j = ...
And how do we rotate a group? We use the same five steps we took in Rust. We save the top left as temp
and then move the values around. We use the monadic functions readArray
and writeArray
to perform these actions in place on our Matrix.
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
rotateImage arr = do
...
where
...
rotateGroup (tlR, tlC) (trR, trC) (brR, brC) (blR, blC) j = do
temp <- readArray arr (tlR, tlC + j)
readArray arr (blR - j, blC) >>= writeArray arr (tlR, tlC + j)
readArray arr (brR, brC - j) >>= writeArray arr (blR - j, blC)
readArray arr (trR + j, trC) >>= writeArray arr (brR, brC - j)
writeArray arr (trR + j, trC) temp
Here’s the full implementation:
rotateImage :: (MArray array Int m) => array (Int, Int) Int -> m ()
rotateImage arr = do
((minR, minC), (maxR, maxC)) <- getBounds arr
let n = maxR - minR + 1
let numLayers = n `quot` 2
rotateLayer (minR, minC) (minR, maxC) (maxR, maxC) (maxR, minC) numLayers
where
rotateLayer _ _ _ _ 0 = return ()
rotateLayer tl@(tlR, tlC) tr@(trR, trC) br@(brR, brC) bl@(blR, blC) n = do
let numGroups = ([0..(trC - tlC - 1)] :: [Int])
forM_ numGroups (rotateGroup tl tr br bl)
rotateLayer (tlR + 1, tlC + 1) (trR + 1, trC - 1) (brR - 1, brC - 1) (blR - 1, blC + 1) (n - 1)
rotateGroup (tlR, tlC) (trR, trC) (brR, brC) (blR, blC) j = do
temp <- readArray arr (tlR, tlC + j)
readArray arr (blR - j, blC) >>= writeArray arr (tlR, tlC + j)
readArray arr (brR, brC - j) >>= writeArray arr (blR - j, blC)
readArray arr (trR + j, trC) >>= writeArray arr (brR, brC - j)
writeArray arr (trR + j, trC) temp
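Since rotateImage is polymorphic over the MArray class, one way to try it out is in the ST monad. Here’s a minimal sketch (the imports and the rotatedElems wrapper are just for this snippet): we thaw an immutable array into a mutable STUArray, rotate it in place, and read the elements back out.
import qualified Data.Array as A
import Data.Array.IArray (elems)
import Data.Array.ST (runSTUArray, thaw)

-- Rotate the 3x3 example in place and return its elements in row-major order.
rotatedElems :: [Int]
rotatedElems = elems $ runSTUArray $ do
  let original = A.listArray ((0, 0), (2, 2)) [1..9] :: A.Array (Int, Int) Int
  marr <- thaw original
  rotateImage marr
  return marr
-- Expected: [7,4,1,8,5,2,9,6,3]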
Conclusion
We’ve got one more Matrix problem to solve next time, and then we’ll move on to some other data structures. To learn more about using Data Structures and Algorithms in Haskell, you can take our Solve.hs course. You’ll get the chance to write a number of data structures from scratch, and you’ll get plenty of practice working with them and using them in algorithms!
Binary Search in a 2D Matrix
In our problem last week, we covered a complex problem that used a binary search. Today, we’ll apply binary search again to solidify our understanding of it. This time, instead of extra algorithmic complexity, we’ll start adding some data structure complexity. We’ll be working with a 2D Matrix instead of basic arrays.
To learn more about data structures and algorithms in Haskell, you should take a look at our Solve.hs course! In particular, you’ll cover multi-dimensional arrays in module 2, and you’ll learn how to write algorithms in Haskell in module 3!
The Problem
Today’s problem is Search a 2D Matrix, and the description is straightforward. We’re given a 2D m x n matrix, as well as a target number. We have to return a boolean for whether or not that number is in the Matrix.
This is trivial with a simple scan, but we have an additional constraint that lets us solve the problem faster. The matrix is essentially ordered. Each row is non-decreasing, and the first element of each successive row is no smaller than the last element of the preceding row.
This allows us to get a solution that is O(log(n + m)), a considerable improvement over a linear scan.
The Algorithm
The algorithm is simple as well. We’ll do two binary searches. First, we’ll search over the rows to identify the last row which could contain the element. Then we’ll do a binary search of that row to see if the element is present or not.
We’ll have a slightly different form to our searches compared to last time. In last week’s problem, we knew we had to find a valid index for our search. Now, we may find that no valid index exists.
So we’ll structure our search interval in a semi-open fashion. The first index in our search interval is inclusive, meaning that it could still be a valid index. The second index is exclusive, meaning it is the lowest index that we consider invalid.
In mathematical notation, we would represent such an interval with a square bracket on the left and a parenthesis on the right. So if that interval is [0, 4), then 0, 1, 2, 3 are valid values. The interval [2,2) would be considered empty, with no valid values. We’ll see how we apply this idea in practice.
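Here’s a tiny Haskell illustration of the convention, just for intuition (it isn’t part of the solution):
-- The semi-open interval [low, hi): low may still be valid, hi has been ruled out.
inInterval :: Int -> (Int, Int) -> Bool
inInterval i (low, hi) = low <= i && i < hi
-- inInterval 3 (0, 4) == True
-- inInterval 2 (2, 2) == False  (the interval [2,2) is empty)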
Rust Solution
We don’t have that many terms to define at the start of this solution. We’ll save the size of both dimensions, and then prepare ourselves for the first binary search by assigning low as 0 (the first potential “valid” answer), hi as m (the lowest “invalid” answer), and creating our output rowWithTarget value. For this, we also assign m, an invalid value. If we fail to re-assign rowWithTarget in our binary search, we want it assigned to an easily testable invalid value.
pub fn search_matrix(matrix: Vec<Vec<i32>>, target: i32) -> bool {
let m = matrix.len();
let n = matrix[0].len();
let mut low = 0;
let mut hi = m;
let mut rowWithTarget = m;
...
}
Now we write our first binary search, looking for a row that could contain our target value. We maintain the typical pattern of binary search, using the loop while (low < hi) and assigning mid = (low + hi) / 2.
pub fn search_matrix(matrix: Vec<Vec<i32>>, target: i32) -> bool {
...
while (low < hi) {
let mid: usize = (low + hi) / 2;
if (matrix[mid][0] > target) {
hi = mid;
} else if (matrix[mid][n - 1] < target) {
low = mid + 1;
} else {
rowWithTarget = mid;
break;
}
}
if (rowWithTarget >= m) {
return false;
}
...
}
If the first element of the row is too large, we know that mid is “invalid”, so we can assign it as hi and continue. If the last element is too small, then we reassign low as mid + 1, as we want low to still be a potentially valid value.
Otherwise, we have found a potential row, so we assign rowWithTarget and break. If, after this search, rowWithTarget has the “invalid” value of m, we can return false, as there are no valid values.
Now we just do the same thing over again, but within rowWithTarget! We reassign low and hi (as n this time) to reset the while loop. And now our comparisons will look at the specific value matrix[rowWithTarget][mid].
pub fn search_matrix(matrix: Vec<Vec<i32>>, target: i32) -> bool {
...
low = 0;
hi = n;
while (low < hi) {
let mid: usize = (low + hi) / 2;
if (matrix[rowWithTarget][mid] > target) {
hi = mid;
} else if (matrix[rowWithTarget][mid] < target) {
low = mid + 1;
} else {
return true;
}
}
return false;
}
Again, we follow the same pattern of re-assigning low and hi. If we don’t hit the return true case in the loop, we’ll end up with return false at the end, because we haven’t found the target.
Here’s the full solution:
pub fn search_matrix(matrix: Vec<Vec<i32>>, target: i32) -> bool {
let m = matrix.len();
let n = matrix[0].len();
let mut low = 0;
let mut hi = m;
let mut rowWithTarget = m;
while (low < hi) {
let mid: usize = (low + hi) / 2;
if (matrix[mid][0] > target) {
hi = mid;
} else if (matrix[mid][n - 1] < target) {
low = mid + 1;
} else {
rowWithTarget = mid;
break;
}
}
if (rowWithTarget >= m) {
return false;
}
low = 0;
hi = n;
while (low < hi) {
let mid: usize = (low + hi) / 2;
if (matrix[rowWithTarget][mid] > target) {
hi = mid;
} else if (matrix[rowWithTarget][mid] < target) {
low = mid + 1;
} else {
return true;
}
}
return false;
}
Haskell Solution
In our Haskell solution, the main difference of course will be using recursion for the binary search. However, we’ll also change up the data structure a bit. In the Rust framing of the problem, we had a vector of vectors of values. We could do this in Haskell, but we could also use Array (Int, Int) Int. This lets us map row/column pairs to numbers in a more intuitive way.
import qualified Data.Array as A
search2DMatrix :: A.Array (Int, Int) Int -> Int -> Bool
search2DMatrix matrix target = ...
where
((minR, minC), (maxR, maxC)) = A.bounds matrix
Another unique feature of arrays is that the bounds don’t have to start from 0. We can have totally custom bounding dimensions for our rows and columns. So instead of using m and n, we’ll need to use the min and max of the row and column dimensions.
So now let’s define our first binary search, looking for the valid row. As we did last week, the input to our function will be two Int values, for the low and hi. As in our Rust solution, we’ll access the first and last element of the row defined by the “middle” of low and hi, and compare them against the target. We make recursive calls to searchRow if the row isn’t valid.
search2DMatrix :: A.Array (Int, Int) Int -> Int -> Bool
search2DMatrix matrix target = result
where
((minR, minC), (maxR, maxC)) = A.bounds matrix
searchRow :: (Int, Int) -> Int
searchRow (low, hi) = if low >= hi then maxR + 1 else
let mid = (low + hi) `quot` 2
firstInRow = matrix A.! (mid, minC)
lastInRow = matrix A.! (mid, maxC)
in if firstInRow > target
then searchRow (low, mid)
else if lastInRow < target
then searchRow (mid + 1, hi)
else mid
rowWithTarget = searchRow (minR, maxR + 1)
result = rowWithTarget <= maxR && ...
Instead of m, we have maxR + 1, which we use as the initial hi value, as well as a return value in the base case where low meets hi. We can return a result of False if rowWithTarget does not come back with a value that is at most maxR.
Now for our second search, we follow the same pattern, but now we’re returning a boolean. The base case returns False, and we return True if we find the value in rowWithTarget at position mid. Here’s what that looks like:
search2DMatrix :: A.Array (Int, Int) Int -> Int -> Bool
search2DMatrix matrix target = result
where
...
rowWithTarget = searchRow (minR, maxR + 1)
searchCol :: (Int, Int) -> Bool
searchCol (low, hi) = low < hi &&
let mid = (low + hi) `quot` 2
val = matrix A.! (rowWithTarget, mid)
in if val > target
then searchCol (low, mid)
else if val < target
then searchCol (mid + 1, hi)
else True
result = rowWithTarget <= maxR && searchCol (minC, maxC + 1)
You’ll see we now use the outcome of searchCol for result. And this completes our solution! Here’s the full code:
search2DMatrix :: A.Array (Int, Int) Int -> Int -> Bool
search2DMatrix matrix target = result
where
((minR, minC), (maxR, maxC)) = A.bounds matrix
searchRow :: (Int, Int) -> Int
searchRow (low, hi) = if low >= hi then maxR + 1 else
let mid = (low + hi) `quot` 2
firstInRow = matrix A.! (mid, minC)
lastInRow = matrix A.! (mid, maxC)
in if firstInRow > target
then searchRow (low, mid)
else if lastInRow < target
then searchRow (mid + 1, hi)
else mid
rowWithTarget = searchRow (minR, maxR + 1)
searchCol :: (Int, Int) -> Bool
searchCol (low, hi) = low < hi &&
let mid = (low + hi) `quot` 2
val = matrix A.! (rowWithTarget, mid)
in if val > target
then searchCol (low, mid)
else if val < target
then searchCol (mid + 1, hi)
else True
result = rowWithTarget <= maxR && searchCol (minC, maxC + 1)
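As a small usage sketch (assuming the qualified Data.Array import from above), here’s a sorted 3x4 matrix with one hit and one miss:
-- 16 is present in the matrix, 13 is not.
exampleSearch :: (Bool, Bool)
exampleSearch = (search2DMatrix m 16, search2DMatrix m 13)
  where
    m = A.listArray ((0, 0), (2, 3)) [1, 3, 5, 7, 10, 11, 16, 20, 23, 30, 34, 60]
-- Expected: (True, False)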
Conclusion
Next week, we’ll stay on the subject of 2D matrices, but we’ll learn about array mutation. This is a very tricky subject in Haskell, so make sure to come back for that article!
To learn how these data structures work in Haskell, read about Solve.hs, our Haskell Data Structures & Algorithms course!
Binary Search in Haskell and Rust
This week we’ll be continuing our series of problem solving in Haskell and Rust. But now we’re going to start moving beyond the terrain of “basic” problem solving techniques with strings, lists and arrays, and start moving in the direction of more complicated data structures and algorithms. Today we’ll explore a problem that is still array-based, but uses a tricky algorithm that involves binary search!
You’ll learn more about Data Structures and Algorithms in our Solve.hs course! The last 7 weeks or so of blog articles have focused on the types of problems you’ll see in Module 1 of that course, but now we’re going to start encountering ideas from Modules 2 & 3, which look extensively at essential data structures and algorithms you need to know for problem solving.
The Problem
Today’s problem is median of two sorted arrays. In this problem, we receive two arrays of numbers as input, each of them in sorted order. The arrays are not necessarily of the same size. Our job is to find the median of the cumulative set of numbers.
Now there’s a conceptually easy approach to this. We could simply scan through the two arrays, keeping track of one index for each one. We would increase the index for whichever number is currently smaller, and stop once we have passed by half of the total numbers. This approach is essentially the “merge” part of merge sort, and it would take O(n) time, since we are scanning half of all the numbers.
However, there’s a faster approach! And if you are asked this question in an interview for anything other than a very junior position, your interviewer will expect you to find this faster approach. Because the arrays are sorted, we can leverage binary search to find the median in O(log n) time. The approach isn’t easy to see though! Let’s go over the algorithm before we get into any code.
The Algorithm
This algorithm is a little tricky to follow (this problem is rated as “hard” on LeetCode). So we’re going to treat this a bit like a mathematical proof, and begin by defining useful terms. Then it will be easy to describe the coding concepts behind the algorithm.
Defining our Terms
Our input consists of 2 arrays, arr1 and arr2, with potentially different sizes n and m, respectively. Without loss of generality, let arr1 be the “shorter” array, so that n <= m. We’ll also define t as the total number of elements, n + m.
It is worthwhile to note right off the bat that if t
is odd, then a single element from one of the two lists will be the median. If t
is even, then we will average two elements together. Even though we won’t actually create the final merged array, we can imagine that it consists of 3 parts:
- The “prior” portion - all numbers before the median element(s)
- The median element(s), either 1 or 2.
- The “latter” portion - all numbers after the median element(s)
The total number of elements in the “prior” portion will end up being (t - 1) / 2, bearing in mind how integer division works. For example, whether t is 15 or 16, we get 7 elements in the “prior” portion. We’ll use p for this number.
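In code, this is just integer division, and a one-liner makes the odd/even behavior concrete:
-- The size of the "prior" portion for a total of t elements.
priorCount :: Int -> Int
priorCount t = (t - 1) `quot` 2
-- priorCount 15 == 7
-- priorCount 16 == 7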
Finally, let’s imagine p1, the number of elements from arr1 that will end up in the prior portion. If we know p1, then p2, the number of elements from arr2 in the prior portion, is fixed, because p1 + p2 = p. We can then think of p1 as an index into arr1: the index of the first element that is not in the prior portion. The only trick is that this index could be n, indicating that all elements of arr1 are in the prior portion.
Getting the Final Answer from our Terms
If we have the “correct” values for p1 and p2, then finding the median is easy. If t is odd, then the lower number between arr1[p1] and arr2[p2] is the median. If t is even, then we average the two smallest numbers among (arr1[p1], arr2[p2], arr1[p1 + 1], arr2[p2 + 1]).
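As a hedged sketch of just this final step (using plain Haskell lists rather than the arrays and vectors in the actual solutions), the logic looks like this. Indices that run past the end of a list are simply dropped from the candidate set:
import Data.List (sort)
import Data.Maybe (catMaybes)

-- Given correct split points p1 and p2, compute the median.
medianFromSplit :: [Double] -> [Double] -> Int -> Int -> Double
medianFromSplit arr1 arr2 p1 p2 =
  let t = length arr1 + length arr2
      at xs i = if i < length xs then Just (xs !! i) else Nothing
      candidates = sort $ catMaybes
        [at arr1 p1, at arr2 p2, at arr1 (p1 + 1), at arr2 (p2 + 1)]
  in if odd t
       then head candidates                          -- the single middle element
       else (candidates !! 0 + candidates !! 1) / 2  -- average the two middle elements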
So we’ve reduced this problem to a matter of finding p1, since p2 can be easily derived from it. How do we know we have the “correct” value for p1, and how do we search for it efficiently?
Solving for p1
The answer is that we will conduct a binary search on arr1 in order to find the correct value of p1. For any particular choice of p1, we determine the corresponding value of p2. Then we make two comparisons:
- Compare arr1[p1 - 1] to arr2[p2]
- Compare arr2[p2 - 1] to arr1[p1]
If both comparisons are less-than-or-equal, then our two p values are correct! The slices arr1[0..p1-1] and arr2[0..p2-1] always constitute a total of p values, and if these values are no greater than arr1[p1] and arr2[p2], then they constitute the entire “prior” set.
If, on the other hand, the first comparison yields “greater than”, then we have too many values from arr1 in our prior set. This means we need to recursively do the binary search on the left side of arr1, since p1 should be smaller.
Then if the second comparison yields “greater than”, we have too few values from arr1
in the “prior” set. We should increase p1
by searching the right half of our array.
This provides a complete algorithm for us to follow!
Rust Implementation
Our algorithm description was quite long, but the advantage of having so many details is that the code starts to write itself! We’ll start with our Rust implementation. Stage 1 is to define all of the terms using our input values. We want to define our sizes and array references generically so that arr1
is the shorter array:
pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
let mut n = nums1.len();
let mut m = nums2.len();
let mut arr1: &Vec<i32> = &nums1;
let mut arr2: &Vec<i32> = &nums2;
if (m < n) {
n = nums2.len();
m = nums1.len();
arr1 = &nums2;
arr2 = &nums1;
}
let t = n + m;
let p: usize = (t - 1) / 2;
...
}
Anatomy of a Binary Search
The next stage is the binary search, so we can find p1 and p2. Now a binary search is a particular kind of loop pattern. Like many of the loop patterns we worked with in the previous weeks, we can express it recursively, or with a loop construct like for or while. We’ll start with a while loop solution for Rust, and then show the recursive solution with Haskell.
All loops maintain some kind of state. For a binary search, the primary state is the two endpoints representing our “interval of interest”. This starts out as the entire interval, and shrinks by half each time until we’ve narrowed it to a single element (or no elements). We’ll represent these interval endpoints with low and hi. Our loop concludes once low is as large as hi.
let mut low = 0;
// Use the shorter array size!
let mut hi = n;
while (low < hi) {
...
}
In our particular case, we are also trying to determine the values for p1
and p2
. Each time we specify an interval, we’ll see if the midpoint of that interval (between low
and hi
) is the correct value of p1
:
...
let mut low = 0;
let mut hi = n;
let mut p1 = 0;
let mut p2 = 0;
while (low < hi) {
p1 = (low + hi) / 2;
p2 = p - p1;
...
}
Now we evaluate this p1
value using the two conditions we specified in our algorithm. These are self-explanatory, except we do need to cover some edge cases where one of our values is at the edge of the array bounds.
For example, if p1
is 0, the first condition is always “true”. If this condition failed, it would mean we want fewer elements from arr1
, but this is impossible if p1
is 0.
...
let mut low = 0;
let mut hi = n;
let mut p1 = 0;
let mut p2 = 0;
while (low < hi) {
p1 = (low + hi) / 2;
p2 = p - p1;
let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
let cond2 = p1 == n || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
if (cond1 && cond2) {
break;
} else if (!cond1) {
p1 -= 1;
hi = p1;
} else {
p1 += 1;
low = p1;
}
}
p2 = p - p1;
...
If both conditions are met, you’ll see we break
, because we’ve found the right value for p1
! Otherwise, we know p1
is invalid. This means we want to exclude the existing p1
value from further consideration by changing either low
or hi
to remove it from the interval of interest.
So if cond1
is false, hi
becomes p1 - 1
, and if cond2
is false, low becomes p1 + 1
. In both cases, we also modify p1
itself first so that our loop does not conclude with p1
in an invalid location.
Getting the Final Answer
Now that we have p1
and p2
, we have to do a couple final tricks to get the final answer. We want to take the smaller value between arr1[p1]
and arr2[p2]
. But we have to handle the edge case where p1
might be n
, and we also want to increment the index for whichever array we take from. Note that p2
cannot be out of bounds right now!
let mut median = arr2[p2];
if (p1 < n && arr1[p1] < arr2[p2]) {
median = arr1[p1];
p1 += 1;
} else {
p2 += 1;
}
If the total number of elements is odd, we can simply return this number (converting to a float). However, in the even case we need one more number to take an average. So we’ll compare the values at the indices again, but now accounting for the fact that either (but not both) could be out of bounds.
let mut median = arr2[p2];
if (p1 < n && arr1[p1] < arr2[p2]) {
median = arr1[p1];
p1 += 1;
} else {
p2 += 1;
}
if (t % 2 == 0) {
if (p1 >= n) {
median += arr2[p2];
} else if (p2 >= m) {
median += arr1[p1];
} else {
median += std::cmp::min(arr1[p1], arr2[p2]);
}
let medianF: f64 = median.into();
return medianF / 2.0;
} else {
return median.into();
}
Here’s the complete implementation:
pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
let mut n = nums1.len();
let mut m = nums2.len();
let mut arr1: &Vec<i32> = &nums1;
let mut arr2: &Vec<i32> = &nums2;
if (m < n) {
n = nums2.len();
m = nums1.len();
arr1 = &nums2;
arr2 = &nums1;
}
let t = n + m;
let p: usize = (t - 1) / 2;
let mut low = 0;
let mut hi = n;
let mut p1 = 0;
let mut p2 = 0;
while (low < hi) {
p1 = (low + hi) / 2;
p2 = p - p1;
let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
let cond2 = p1 == n || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
if (cond1 && cond2) {
break;
} else if (!cond1) {
p1 -= 1;
hi = p1;
} else {
p1 += 1;
low = p1;
}
}
p2 = p - p1;
let mut median = arr2[p2];
if (p1 < n && arr1[p1] < arr2[p2]) {
median = arr1[p1];
p1 += 1;
} else {
p2 += 1;
}
if (t % 2 == 0) {
if (p1 >= n) {
median += arr2[p2];
} else if (p2 >= m) {
median += arr1[p1];
} else {
median += std::cmp::min(arr1[p1], arr2[p2]);
}
let medianF: f64 = median.into();
return medianF / 2.0;
} else {
return median.into();
}
}
Haskell Implementation
Now let’s examine the Haskell implementation. Unlike the LeetCode version, we’ll just assume our inputs are Double
already instead of doing a conversion. Once again, we start by defining the terms:
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
where
n' = V.length input1
m' = V.length input2
t = n' + m'
p = (t - 1) `quot` 2
(n, m, arr1, arr2) = if V.length input1 <= V.length input2
then (n', m', input1, input2) else (m', n', input2, input1)
...
Now we’ll implement the binary search, this time doing a recursive function. We’ll do this in two parts, starting with a helper function. This helper function will simply tell us if a particular index is correct for p1
. The trick though is that we’ll return an Ordering
instead of just a Bool
:
-- data Ordering = LT | EQ | GT
f :: Int -> Ordering
This lets us signal 3 possibilities. If we return EQ
, this means the index is valid. If we return LT
, this will mean we want fewer values from arr1
. And then GT
means we want more values from arr1
.
With this framing it’s easy to see the implementation of this helper now. We determine the appropriate p2
, figure out our two conditions, and return the value for each condition:
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
where
...
f :: Int -> Ordering
f pi1 =
let pi2 = p - pi1
cond1 = pi1 == 0 || arr1 V.! (pi1 - 1) <= arr2 V.! pi2
cond2 = pi1 == n || pi2 == 0 || (arr2 V.! (pi2 - 1) <= arr1 V.! pi1)
in if cond1 && cond2 then EQ else if (not cond1) then LT else GT
Now we can use this helper in a recursive binary search. The binary search tracks two pieces of state for our interval ((Int, Int)
), and it will return the correct value for p1
. The implementation applies the base case (return low
if low >= hi
), determines the midpoint, calls our helper, and then recurses appropriately based on the helper result.
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
where
...
f :: Int -> Ordering
f pi1 = ...
search :: (Int, Int) -> Int
search (low, hi) = if low >= hi then low else
let mid = (low + hi) `quot` 2
in case f mid of
EQ -> mid
LT -> search (low, mid - 1)
GT -> search (mid + 1, hi)
p1 = search (0, n)
p2 = p - p1
...
For the final part of the problem, we’ll define a helper. Given p1
and p2
, it will emit the “lower” value between the two indices in the array (accounting for edge cases) as well as the two new indices (since one will increment).
This is a matter of lazily defining the “next” value for each array, the “end” condition of each array, and the “result” if that array’s value is chosen:
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
where
...
findNext pi1 pi2 =
let next1 = arr1 V.! pi1
next2 = arr2 V.! pi2
end1 = pi1 >= n
end2 = pi2 >= m
res1 = (next1, pi1 + 1, pi2)
res2 = (next2, pi1, pi2 + 1)
in if end1 then res2
else if end2 then res1
else if next1 <= next2 then res1 else res2
Now we just apply this either once or twice to get our result!
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = result
where
...
tIsEven = even t
(median1, nextP1, nextP2) = findNext p1 p2
(median2, _, _) = findNext nextP1 nextP2
result = if tIsEven
then (median1 + median2) / 2.0
else median1
Here’s the complete implementation:
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = result
where
n' = V.length input1
m' = V.length input2
t = n' + m'
p = (t - 1) `quot` 2
(n, m, arr1, arr2) = if V.length input1 <= V.length input2
then (n', m', input1, input2) else (m', n', input2, input1)
-- Evaluate the index in arr1
-- If it indicates the index can be part of a median, return EQ
-- If it indicates we need to move left in arr1, return LT
-- If it indicates we need to move right in arr1, return GT
-- Precondition: p1 <= n
f :: Int -> Ordering
f pi1 =
let pi2 = p - pi1
cond1 = pi1 == 0 || arr1 V.! (pi1 - 1) <= arr2 V.! pi2
cond2 = pi1 == n || pi2 == 0 || (arr2 V.! (pi2 - 1) <= arr1 V.! pi1)
in if cond1 && cond2 then EQ else if (not cond1) then LT else GT
search :: (Int, Int) -> Int
search (low, hi) = if low >= hi then low else
let mid = (low + hi) `quot` 2
in case f mid of
EQ -> mid
LT -> search (low, mid - 1)
GT -> search (mid + 1, hi)
findNext pi1 pi2 =
let next1 = arr1 V.! pi1
next2 = arr2 V.! pi2
end1 = pi1 >= n
end2 = pi2 >= m
res1 = (next1, pi1 + 1, pi2)
res2 = (next2, pi1, pi2 + 1)
in if end1 then res2
else if end2 then res1
else if next1 <= next2 then res1 else res2
p1 = search (0, n)
p2 = p - p1
tIsEven = even t
(median1, nextP1, nextP2) = findNext p1 p2
(median2, _, _) = findNext nextP1 nextP2
result = if tIsEven
then (median1 + median2) / 2.0
else median1
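As a quick sanity check (assuming the definition above is in scope, with Data.Vector imported qualified as V), we can try a couple of small inputs:
-- medianSortedArrays (V.fromList [1, 3]) (V.fromList [2, 4, 5, 6]) == 3.5
-- medianSortedArrays (V.fromList [1, 2]) (V.fromList [3])          == 2.0
main :: IO ()
main = do
  print (medianSortedArrays (V.fromList [1, 3]) (V.fromList [2, 4, 5, 6]))
  print (medianSortedArrays (V.fromList [1, 2]) (V.fromList [3]))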
Conclusion
If you want to learn more about these kinds of problem solving techniques, you should take our course Solve.hs! In the coming weeks, we’ll see more problems related to data structures and algorithms, which are covered extensively in Modules 2 and 3 of that course!
Buffer & Save with a Challenging Example
Welcome back to our series comparing LeetCode problems in Haskell and Rust. Today we’ll learn a new paradigm that I call “Buffer and Save”. This will also be the hardest problem we’ve done so far! The core loop structure isn’t that hard, but there are a couple layers of tricks to massage our data to get the final answer.
This will be the last problem we do that focuses strictly on string and list manipulation. The next set of problems we do will all rely on more advanced data structures or algorithmic ideas.
For more complete practice on problem solving in Haskell, check out Solve.hs, our newest course. This course will teach you everything you need to know about problem solving, data structures, and algorithms in Haskell. You’ll get loads of practice building structures and algorithms from scratch, which is very important for understanding and remembering how they work.
The Problem
Today’s problem is Text Justification. The idea here is that we are taking a list of words and a “maximum width” and printing out the words grouped into equal-width lines that are evenly spaced. Here’s an example input and output:
Example Input (list of 9 strings):
[“Study”, “Haskell”, “with”, “us”, “every”, “Monday”, “Morning”, “for”, “fun”]
Max Width: 16
Output (list of 4 strings):
“Study    Haskell”
“with   us  every”
“Monday   Morning”
“for fun         ”
There are a few notable rules, constraints, and edge cases. Here’s a list to summarize them:
- There is at least one word
- No word is larger than the max width
- All output strings must have max width as their length (including spaces)
- The first word of every line is set to the left
- The last line always has 1 space between words, and then enough spaces after the last word to reach the max width.
- All other lines with multiple words will align the final word all the way to the right
- The spaces in non-final lines are distributed as evenly as possible, but extra spaces go between words to the left.
The final point is potentially the trickiest to understand. Consider the second line above, with us every
. The max width is 16, and we have 3 words with a total of 11 characters. This leaves us 5 spaces. Having 3 words means 2 blanks, so the “left” blank gets 3 spaces and the “right” blank gets 2 spaces.
If you had a line with 5 words, a max width of 30, and 16 characters, you would place 4 spaces in the left two blanks, and 3 spaces in the right two blanks. The relative length of the words does not matter.
Words in Line: [“A”, “good”, “day”, “to”, “endure”]
Output Line:
“A    good    day   to   endure”
The Algorithm
As mentioned above, our main algorithmic idea could be called “buffer and save”. We’ve been defining all of our loops based on the state we must maintain between iterations of the loop. The buffer and save approach highlights two pieces of state for us:
- The strings we’ve accumulated for our answer so far (the “result”)
- A buffer of the strings in the “current” line we’re building.
So we’ll loop through the input words one at a time. We’ll consider if the next word can be added to the “current” line. If it would cause our current line to exceed the maximum width, we’ll “save” our current line and write it out to the “result” list, adding the required spaces.
To help our calculations, we’ll also include two other pieces of state in our loop:
- The number of characters in our “current” line
- The number of words in our “current” line
Finally, there’s the question of how to construct each output line. Combining the math with list-mechanics is a little tricky. But the central idea consists of 4 simple steps:
- Find the number of spaces (subtract number of characters from max width)
- Divide the number of spaces by the number of “blanks” (number of words - 1)
- The quotient is the “base” number of spaces per blank
- The remainder is the number of blanks (starting from the left) that get an extra space
The exact implementation of this idea differs between Haskell and Rust. Again this rests a lot on the “reverse” differences between Rust vectors and Haskell lists.
The final line has a slightly different (but easier) process. And we should note that the final line will still be in our buffer when we exit the loop! So we shouldn’t forget to add it to the result.
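Before we write any real code, here's a small Haskell sketch of just the space math from the list above. The helper spacesPerBlank is purely illustrative; the final solutions arrange this math a little differently.
-- Number of spaces for each blank, from left to right, in a non-final line.
-- charsInLine counts only word characters, not spaces.
spacesPerBlank :: Int -> Int -> Int -> [Int]
spacesPerBlank maxWidth charsInLine wordsInLine =
  let totalSpaces = maxWidth - charsInLine
      numBlanks = wordsInLine - 1
      -- The leftmost 'extra' blanks each get one more space than the rest.
      (base, extra) = quotRem totalSpaces numBlanks
  in replicate extra (base + 1) ++ replicate (numBlanks - extra) base
-- For "with us every" with max width 16: spacesPerBlank 16 11 3 == [3,2]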
Haskell Solution
We know enough now to jump into our Haskell solution. Our solution should be organized around a loop. Since we go through the input word-by-word, this should follow a fold pattern. So here’s our outline:
justifyText :: [String] -> Int -> [String]
justifyText inputWords maxWidth = ...
where
-- f = ‘final’
(fLine, fWordsInLine, fCharsInLine, result) = foldl loop ([], 0, 0, []) inputWords
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord = ...
Let’s focus in on the choice we have to make in the loop. We need to determine if this new word fits in our current line. So we’ll get its length and add it to the number of characters in the line AND consider the number of words in the line. We count the words too since each word we already have requires at least one space!
-- (maxWidth is still in scope here)
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord =
let newWordLen = length newWord
in if newWordLen + charsInLine + wordsInLine > maxWidth
then ...
else ...
How do we fill in these choices? If we don’t overflow the line, we just append the new word, bump the count of the words, and add the new word’s length to the character count.
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord =
let newWordLen = length newWord
in if newWordLen + charsInLine + wordsInLine > maxWidth
then ...
else (newWord : currentLine, wordsInLine + 1, charsInLine + newWordLen, currResult)
The overflow case isn’t hard, but it does require us to have a function that can convert our current line into the final string. This function will also take the number of words and characters in this line. Assuming this function exists, we just make this new line, append it to result
, and then reset our other stateful values so that they only reflect the “new word” as part of our current line.
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord =
let newWordLen = length newWord
resultLine = makeLine currentLine wordsInLine charsInLine
in if newWordLen + charsInLine + wordsInLine > maxWidth
then ([newWord], 1, newWordLen, resultLine : currResult)
else (newWord : currentLine, wordsInLine + 1, charsInLine + newWordLen, currResult)
makeLine :: [String] -> Int -> Int -> String
makeLine = ...
Before we think about the makeLine
implementation though, we just about have enough to fill in the rest of the “top” of our function definition. We’d just need another function for making the “final” line, since this is different from other lines. Then when we get our “final” state values, we’ll plug them into this function to get our final line, append this to the result, and reverse it all.
justifyText :: [String] -> Int -> [String]
justifyText inputWords maxWidth =
reverse (makeLineFinal fLine fWordsInLine fCharsInLine : result)
where
(fLine, fWordsInLine, fCharsInLine, result) = foldl loop ([], 0, 0, []) inputWords
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord =
let newWordLen = length newWord
resultLine = makeLine currentLine wordsInLine charsInLine
in if newWordLen + charsInLine + wordsInLine > maxWidth
then ([newWord], 1, newWordLen, resultLine : currResult)
else (newWord : currentLine, wordsInLine + 1, charsInLine + newWordLen, currResult)
makeLine :: [String] -> Int -> Int -> String
makeLine = ...
makeLineFinal :: [String] -> Int -> Int -> String
makeLineFinal = ...
Now let’s discuss forming these lines, starting with the general case. We can start with a couple edge cases. This should never be called with an empty list. And with a singleton, we just left-align the word and add the right number of spaces:
makeLine :: [String] -> Int -> Int -> String
makeLine [] _ _ = error "Cannot makeLine with empty string!"
makeLine [onlyWord] _ charsInLine =
let extraSpaces = replicate (maxWidth - charsInLine) ' '
in onlyWord <> extraSpaces
makeLine (first : rest) wordsInLine charsInLine = ...
Now we’ll calculate the quotient and remainder to get the spacing sizes, as mentioned in our algorithm section. But how do we combine them? There are multiple ways, but the idea I thought of was to zip
the tail
of the list with the number of spaces it needs to append. Then we can fold it into a resulting list using a function like this:
-- (String, Int) is the next string and the number of spaces after it
combine :: String -> (String, Int) -> String
combine suffix (nextWord, numSpaces) =
nextWord <> replicate numSpaces ' ' <> suffix
Remember while doing this that we’ve accumulated the words for each line in reverse order. So we want to append each one in succession, together with the number of spaces that come after it.
To use this function, we can “fold” over the “tail” of our current line, while using the first word in our list as the base of the fold! Don’t forget the quotRem
math going on in here!
makeLine :: [String] -> Int -> Int -> String
makeLine [] _ _ = error "Cannot makeLine with empty string!"
makeLine [onlyWord] _ charsInLine =
let extraSpaces = replicate (maxWidth - charsInLine) ' '
in onlyWord <> extraSpaces
makeLine (first : rest) wordsInLine charsInLine =
let (baseNumSpaces, numWithExtraSpace) = quotRem (maxWidth - charsInLine) (wordsInLine - 1)
baseSpaces = replicate (wordsInLine - 1 - numWithExtraSpace) baseNumSpaces
extraSpaces = replicate numWithExtraSpace (baseNumSpaces + 1)
wordsWithSpaces = zip rest (baseSpaces <> extraSpaces)
in foldl combine first wordsWithSpaces
combine :: String -> (String, Int) -> String
combine suffix (nextWord, numSpaces) =
nextWord <> replicate numSpaces ' ' <> suffix
To make the final line, we can also leverage our combine
function! It’s just a matter of combining each word in our input with the appropriate number of spaces. In this case, almost every word gets 1 space except for the last one (which comes first in our list). This just gets however many trailing spaces we need!
makeLineFinal :: [String] -> Int -> Int -> String
makeLineFinal [] _ _ = error "Cannot makeLine with empty string!"
makeLineFinal strs wordsInLine charsInLine =
let trailingSpaces = maxWidth - charsInLine - (wordsInLine - 1)
in foldl combine "" (zip strs (trailingSpaces : repeat 1))
Putting all these pieces together, we have our complete solution!
justifyText :: [String] -> Int -> [String]
justifyText inputWords maxWidth =
reverse (makeLineFinal fLine fWordsInLine fCharsInLine : result)
where
(fLine, fWordsInLine, fCharsInLine, result) = foldl loop ([], 0, 0, []) inputWords
loop :: ([String], Int, Int, [String]) -> String -> ([String], Int, Int, [String])
loop (currentLine, wordsInLine, charsInLine, currResult) newWord =
let newWordLen = length newWord
resultLine = makeLine currentLine wordsInLine charsInLine
in if newWordLen + charsInLine + wordsInLine > maxWidth
then ([newWord], 1, newWordLen, resultLine : currResult)
else (newWord : currentLine, wordsInLine + 1, charsInLine + newWordLen, currResult)
makeLine :: [String] -> Int -> Int -> String
makeLine [] _ _ = error "Cannot makeLine with empty string!"
makeLine [onlyWord] _ charsInLine =
let extraSpaces = replicate (maxWidth - charsInLine) ' '
in onlyWord <> extraSpaces
makeLine (first : rest) wordsInLine charsInLine =
let (baseNumSpaces, numWithExtraSpace) = quotRem (maxWidth - charsInLine) (wordsInLine - 1)
baseSpaces = replicate (wordsInLine - 1 - numWithExtraSpace) baseNumSpaces
extraSpaces = replicate numWithExtraSpace (baseNumSpaces + 1)
wordsWithSpaces = zip rest (baseSpaces <> extraSpaces)
in foldl combine first wordsWithSpaces
makeLineFinal :: [String] -> Int -> Int -> String
makeLineFinal [] _ _ = error "Cannot makeLine with empty string!"
makeLineFinal strs wordsInLine charsInLine =
let trailingSpaces = maxWidth - charsInLine - (wordsInLine - 1)
in foldl combine "" (zip strs (trailingSpaces : repeat 1))
combine :: String -> (String, Int) -> String
combine suffix (nextWord, numSpaces) = nextWord <> replicate numSpaces ' ' <> suffix
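As a quick check against the example from the start of this post (assuming the definitions above are in scope; example is just an illustrative name):
example :: [String]
example = justifyText
  ["Study", "Haskell", "with", "us", "every", "Monday", "Morning", "for", "fun"] 16
-- example ==
--   [ "Study    Haskell"
--   , "with   us  every"
--   , "Monday   Morning"
--   , "for fun         "
--   ]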
Rust Solution
Now let’s put together our Rust solution. Since we have a reasonable outline from writing this in Haskell, let’s start with the simpler elements, makeLine
and makeLineFinal
. We’ll use library functions as much as possible for the string manipulation. For example, we can start makeLineFinal
by using join
on our input vector of strings.
pub fn make_line_final(
currentLine: &Vec<&str>,
max_width: usize,
charsInLine: usize) -> String {
let mut result = currentLine.join(" ");
...
}
Now we just need to calculate the number of trailing spaces, subtracting the number of characters in the joined string. We append this to the end by taking a blank space and using repeat
for the correct number of times.
pub fn make_line_final(
currentLine: &Vec<&str>,
max_width: usize,
charsInLine: usize) -> String {
let mut result = currentLine.join(" ");
let trailingSpaces = max_width - result.len();
result.push_str(&" ".repeat(trailingSpaces));
return result;
}
For those unfamiliar with Rust, the type of our input vector might seem odd. When we have &Vec<&str>
, this means a reference to a vector of string slices. String slices are portions of a String
that we hold a reference to, but they aren’t copied. However, when we join
them, we make a new String
result.
Also note that we aren’t passing wordsInLine
as a separate parameter. We can get this value using .len()
in constant time in Rust. In Haskell, length
is O(n)
so we don’t want to always do that.
Now for the general make_line
function, we have the same type signature, but we start with our base case, where we only have one string in our current line. Again, we use repeat
with the number of spaces.
pub fn make_line(
currentLine: &Vec<&str>,
max_width: usize,
charsInLine: usize) -> String {
let mut result = String::new();
let n = currentLine.len();
if (n == 1) {
result.push_str(currentLine[0]);
result.push_str(&" ".repeat(max_width - charsInLine));
return result;
}
...
}
Now we do the “math” portion of this. Rust doesn’t have a single quotRem
function in its base library, so we calculate these values separately.
pub fn make_line(
currentLine: &Vec<&str>,
max_width: usize,
charsInLine: usize) -> String {
let mut result = String::new();
let n = currentLine.len();
if (n == 1) {
result.push_str(currentLine[0]);
result.push_str(&" ".repeat(max_width - charsInLine));
return result;
}
let numSpaces = (max_width - charsInLine);
let baseNumSpaces = numSpaces / (n - 1);
let numWithExtraSpace = numSpaces % (n - 1);
let mut i = 0;
while i < n {
...
}
return result;
}
The while
loop we’ll write here is instructive. We use an index instead of a for each
pattern because the index tells us how many spaces to use. If our index is smaller than numWithExtraSpace
, we add 1 to the base number of spaces. Otherwise we use the base until the index n - 1
. This index has no extra spaces, so we’re done at that point!
pub fn make_line(
currentLine: &Vec<&str>,
max_width: usize,
charsInLine: usize) -> String {
let mut result = String::new();
let n = currentLine.len();
if (n == 1) {
result.push_str(currentLine[0]);
result.push_str(&" ".repeat(max_width - charsInLine));
return result;
}
let numSpaces = (max_width - charsInLine);
let baseNumSpaces = numSpaces / (n - 1);
let numWithExtraSpace = numSpaces % (n - 1);
let mut i = 0;
while i < n {
result.push_str(currentLine[i]);
if i < numWithExtraSpace {
result.push_str(&" ".repeat(baseNumSpaces + 1));
} else if i < n - 1 {
result.push_str(&" ".repeat(baseNumSpaces));
}
i += 1;
}
return result;
}
Now we frame our solution. Let’s start by setting up our state variables (again, omitting numWordsInLine
). We’ll also redefine max_width
as a usize
value for ease of comparison later.
pub fn full_justify(words: Vec<String>, max_width: i32) -> Vec<String> {
let mut currentLine: Vec<&str> = Vec::new();
let mut charsInLine = 0;
let mut result = Vec::new();
let mw = max_width as usize;
...
}
Now we’d like to frame our solution as a “for each” loop. However, this doesn’t work, for Rust-related reasons we’ll describe after the solution! Instead, we’ll use an index loop.
pub fn full_justify(words: Vec<String>, max_width: i32) -> Vec<String> {
let mut currentLine: Vec<&str> = Vec::new();
let mut charsInLine = 0;
let mut result = Vec::new();
let mw = max_width as usize;
let mut i = 0;
let n = words.len();
for i in 0..n {
...
}
}
We’ll get the word
by index on each iteration, and use its length to see if we’ll exceed the max width. If not, we can safely push it onto currentLine
and increase the character count:
pub fn full_justify(words: Vec<String>, max_width: i32) -> Vec<String> {
let mut currentLine: Vec<&str> = Vec::new();
let mut charsInLine = 0;
let mut result = Vec::new();
let mw = max_width as usize;
let mut i = 0;
let n = words.len();
for i in 0..n {
let word = &words[i];
if word.len() + charsInLine + currentLine.len() > mw {
...
} else {
currentLine.push(&words[i]);
charsInLine += word.len();
}
}
}
Now when we do exceed the max width, we have to push our current line onto result
(calling make_line
). We clear the current line, push our new word, and use its length for charsInLine
.
pub fn full_justify(words: Vec<String>, max_width: i32) -> Vec<String> {
let mut currentLine: Vec<&str> = Vec::new();
let mut charsInLine = 0;
let mut result = Vec::new();
let mw = max_width as usize;
let mut i = 0;
let n = words.len();
for i in 0..n {
let word = &words[i];
if word.len() + charsInLine + currentLine.len() > mw {
result.push(make_line(&currentLine, mw, charsInLine));
currentLine.clear();
currentLine.push(&words[i]);
charsInLine = word.len();
} else {
currentLine.push(&words[i]);
charsInLine += word.len();
}
}
...
}
After our loop, we’ll just call make_line_final
on whatever is left in our currentLine
! Here’s our complete full_justify
function that calls make_line
and make_line_final
as we wrote above:
pub fn full_justify(words: Vec<String>, max_width: i32) -> Vec<String> {
let mut currentLine: Vec<&str> = Vec::new();
let mut charsInLine = 0;
let mut result = Vec::new();
let mw = max_width as usize;
let mut i = 0;
let n = words.len();
for i in 0..n {
let word = &words[i];
if word.len() + charsInLine + currentLine.len() > mw {
result.push(make_line(&currentLine, mw, charsInLine));
currentLine.clear();
currentLine.push(&words[i]);
charsInLine = word.len();
} else {
currentLine.push(&words[i]);
charsInLine += word.len();
}
}
result.push(make_line_final(&currentLine, mw, charsInLine));
return result;
}
Why an Index Loop?
Inside our Rust loop, we have an odd pattern in getting the “word” for this iteration. We first assign word = &words[i]
, and then later on, when we push that word, we reference words[i]
again, using currentLine.push(&words[i])
.
Why do this? Why not currentLine.push(word)
? And then, why can’t we just do for word in words
as our loop?
If we write our loop as for word in words
, then we cannot reference the value word
after the loop. It is “scoped” to the loop. However, currentLine
“outlives” the loop! We have to reference currentLine
at the end when we make our final line.
To get around this, we would basically have to copy the word instead of using a string reference &str
, but this is unnecessarily expensive.
These are the sorts of odd “lifetime” quirks you have to learn to deal with in Rust. Haskell is easier in that it spares us from thinking about this. But Rust gains a significant performance boost with these sorts of ideas.
Conclusion
This was definitely the most involved problem we’ve dealt with so far. We learned a new paradigm (buffer and save), and got some experience dealing with some of the odd quirks and edge cases of string manipulation, especially in Rust. It was a fairly tricky problem, as far as list manipulation goes. For an easier example of a buffer and save problem, try solving Merge Intervals.
If you want to level up your Haskell problem solving skills, you need to take our course Solve.hs. This course will teach you everything you need to know about problem solving, data structures, and algorithms in Haskell. After this course, you’ll be in great shape to deal with these sorts of LeetCode style problems as they come up in your projects.
The Sliding Window in Haskell & Rust
In last week’s problem, we covered a two-pointer algorithm, and compared Rust and Haskell solutions as we have been for this whole series. Today, we’ll study a related concept, the sliding window problem. Whereas the general two-pointer problem can often be tackled by a single loop, we’ll have to use nested loops in this problem. This problem will also mark our first use of the Set
data structure in this series.
If you want a deeper look at problem solving techniques in Haskell, you should enroll in our Solve.hs course! You’ll learn everything you need for general problem solving knowledge in Haskell, including data structures, algorithms, and parsing!
The Problem
Today’s LeetCode problem is Longest Substring without Repeating Characters. It’s a lengthy problem name, but the name basically tells you everything you need to know! We want to find a substring of our input that does not repeat any characters within the substring, and then get the longest such substring.
For example, abaca
would give us an answer of 3, since we have the substring bac
that consists of 3 unique characters. However, abaaca
only gives us 2. There is no run of 3 characters where the three characters are all unique.
The Algorithm
The approach we’ll use, as mentioned above, is called a sliding window algorithm. In some ways, this is similar to the two-pointer approach last week. We’ll have, in a sense, two different pointers within our input. One dictates the “left end” of a window and one dictates the “right end” of a window. Unlike last week’s problem though, both pointers will move in the same direction, rather than converging from opposite directions.
The goal of a sliding window problem is “find a continuous subsequence of an input that matches the criteria”. And for many problems like ours, you want to find the longest such subsequence. The main process for a sliding window problem is this:
- Grow the window by increasing the “right end” until (or while) the predicate is satisfied
- Once you cannot grow the window any more, shrink the window by increasing the “left end” until we’re in a position to grow the window again.
- Continue until one or both pointers go off the end of the input list.
So for our problem today, we want to “grow” our sliding window as long as we can get more unique characters. Once we hit a character we’ve already seen in our current window, we’ll need to shrink the window until that duplicate character is removed from the set.
As we’re doing this, we’ll need to keep track of the largest substring size we’ve seen so far.
Here are the steps we would take with the input abaca
. At each step, we process a new input character.
1. Index 0 (‘a’) - window is “a” which is all unique.
2. Index 1 (‘b’) - window is “ab” which is all unique
3. Index 2 (‘a’) - window is “aba”, which is not all unique
3b. Shrink window, removing first ‘a’, so it is now “ba”
4. Index 3 (‘c’) - window is “bac”, which is all unique
5. Index 4 (‘a’) - window is “baca”, which is not unique
5b. Shrink window, remove ‘b’ and ‘a’, leaving “ca”
The largest unique window we saw was bac
, so the final answer is 3.
Haskell Solution
For a change of pace, let’s discuss the Haskell approach first. Our algorithm is laid out in such a way that we can process one character at a time. Each character either grows the window, or forces it to shrink to accommodate the character. This means we can use a fold!
Let’s think about what state we need to track within this fold. Naturally, we want to track the current “set” of characters in our window. Each time we see the next character, we have to quickly determine if it’s already in the window. We’ll also want to track the largest set size we’ve seen so far, since by the end of the string our window might no longer reflect the largest subsequence.
With a general sliding window approach, you would also need to track both the start and the end index of your current window. In this problem though, we can get away with just tracking the start index. We can always derive the end index by taking the start index and adding the size of the set. And since we’re iterating through the characters anyway, we don’t need the end index to get the “next” character.
This means our fold-loop function will have this type signature:
-- State: (start index, set of letters, largest seen)
loop :: (Int, S.Set Char, Int) -> Char -> (Int, S.Set Char, Int)
Now, using our idea of “beginning from the end”, we can already write the invocation of this loop:
largestUniqueSubsequence :: String -> Int
largestUniqueSubsequence input = best
where
(_, _, best) = foldl loop (0, S.empty, 0) input
loop :: (Int, S.Set Char, Int) -> Char -> (Int, S.Set Char, Int)
...
Using 0
for the start index right away is a little hand-wavy, since we haven’t actually added the first character to our set yet! But if we see a single character, we’ll always add it, and as we’ll see, the “adding” branch of our loop never increases this number.
With that in mind, let’s write this branch of our loop handler! If we have not seen the next character in the string, we keep the same start index (left side of the window isn’t moving), we add the character to our set, and we take the new size of the set as the “best” value if it’s greater than the original. We get the new size by adding 1 to the original set size.
largestUniqueSubsequence :: String -> Int
largestUniqueSubsequence input = best
where
(_, _, best) = foldl loop (0, S.empty, 0) input
loop :: (Int, S.Set Char, Int) -> Char -> (Int, S.Set Char, Int)
loop (startIndex, charSet, bestSoFar) c = if S.notMember c charSet
then (startIndex, S.insert c charSet, max bestSoFar (S.size charSet + 1))
else ...
Now we reach the tricky case! If we’ve already seen the next character, we need to remove characters from our set until we reach the instance of this character in the set. Since we might need to remove multiple characters, “shrinking” is an iterative process with a variable number of steps. This means it would be a while-loop in most languages, which means we need another recursive function!
The goal of this function is to change two of our stateful values (the start index and the character set) until we can once again have a unique character set with the new input character. So each iteration it takes the existing values for these, and will ultimately return updated values. Here’s its type signature:
shrink :: (Int, S.Set Char) -> Char -> (Int, S.Set Char)
Before we implement this, we can invoke it in our primary loop! When we’ve seen the new character in our set, we shrink the input to match this character, and then return these new stateful values along with our previous best (shrinking never increases the size).
largestUniqueSubsequence :: String -> Int
largestUniqueSubsequence input = best
where
(_, _, best) = foldl loop (0, S.empty, 0) input
loop :: (Int, S.Set Char, Int) -> Char -> (Int, S.Set Char, Int)
loop (startIndex, charSet, bestSoFar) c = if S.notMember c charSet
then (startIndex, S.insert c charSet, max bestSoFar (S.size charSet + 1))
else
let (newStart, newSet) = shrink (startIndex, charSet) c
in (newStart, newSet, bestSoFar)
shrink :: (Int, S.Set Char) -> Char -> (Int, S.Set Char)
shrink = undefined
Now we implement “shrink” by considering the base case and recursive case. In the base case, the character at this index matches the duplicate character we’re trying to remove. So we can return the same set of characters, but increase the index.
In the recursive case, we still increase the index, but now we remove the character at the start index from the set without replacement. (Note how we need a vector for efficient indexing here).
largestUniqueSubsequence :: String -> Int
largestUniqueSubsequence input = best
where
inputV = V.fromList input -- vector of the input for O(1) indexing in shrink
(_, _, best) = foldl loop (0, S.empty, 0) input
loop :: (Int, S.Set Char, Int) -> Char -> (Int, S.Set Char, Int)
loop (startIndex, charSet, bestSoFar) c = if S.notMember c charSet
then (startIndex, S.insert c charSet, max bestSoFar (S.size charSet + 1))
else
let (newStart, newSet) = shrink (startIndex, charSet) c
in (newStart, newSet, bestSoFar)
shrink :: (Int, S.Set Char) -> Char -> (Int, S.Set Char)
shrink (startIndex, charSet) c =
let nextC = inputV V.! startIndex
-- Base Case: nextC is equal to c
in if nextC == c then (startIndex + 1, charSet)
-- Recursive Case: remove the character at startIndex
else shrink (startIndex + 1, S.delete nextC charSet) c
Now we have a complete Haskell solution!
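As a quick check against the examples from the problem statement (assuming the definition above is in scope, with Data.Set as S and Data.Vector as V imported):
main :: IO ()
main = do
  print (largestUniqueSubsequence "abaca")  -- 3
  print (largestUniqueSubsequence "abaaca") -- 2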
Rust Solution
Now in our Rust solution, we’ll follow the same pattern we’ve been doing for these problems. We’ll set up our loop variables, write the loop, and handle the different cases in the loop. Because we had the nested recursive “shrink” function in Haskell, this will translate to a “while” loop in Rust, nested within our for-loop.
Here’s how we set up our loop variables:
pub fn length_of_longest_substring(s: String) -> i32 {
let mut best = 0;
let mut startIndex = 0;
let inputV: Vec<char> = s.chars().collect();
let mut charSet = HashSet::new();
for c in s.chars() {
...
}
}
Within the loop, we have the “easy” case, where the next character is not already in our set. We just insert it into our set, and we update best
if we have a new maximum.
pub fn length_of_longest_substring(s: String) -> i32 {
let mut best = 0;
let mut startIndex = 0;
let inputV: Vec<char> = s.chars().collect();
let mut charSet = HashSet::new();
for c in s.chars() {
if charSet.contains(&c) {
...
} else {
charSet.insert(c);
best = std::cmp::max(best, charSet.len());
}
}
return best as i32;
}
The Rust-specific oddity is that when we call contains
on the HashSet
, we must use &c
, passing a reference to the character. In C++ we could just copy the character, or it could be handled by the function using const&
. But Rust handles these things a little differently.
Now we get to the “tricky” case within our loop. How do we “shrink” our set to consume a new character?
In our case, we’ll actually just use the loop
functionality of Rust, which works like while (true)
, requiring a manual break
inside the loop. Our idea is that we’ll inspect the character at the “start” index of our window. If this character is the same as the new character, we will advance the start index (indicating we are dropping the old version), but then we’ll break
. Otherwise, we’ll still increase the index, but we’ll remove the other character from the set as well.
Here’s what this loop looks like in relative isolation:
if charSet.contains(&c) {
loop {
// Look at “first” character of window
let nextC = inputV[startIndex];
if (nextC == c) {
// If it’s the new character, we advance past it and break
startIndex += 1;
break;
} else {
// Otherwise, advance AND drop it from the set
startIndex += 1;
charSet.remove(&nextC);
}
}
} else {
...
}
The inner condition (nextC == c
) feels a little flimsy to use with a while (true)
loop. But it’s perfectly sound because of the invariant that if charSet
contains c
, we’ll necessarily find nextC == c
before startIndex
gets too large. We could also write it as a normal while
loop, but loop
is an interesting Rust-specific idea to bring in here.
Here’s our complete Rust solution!
use std::collections::HashSet;
pub fn length_of_longest_substring(s: String) -> i32 {
let mut best = 0;
let mut startIndex = 0;
let inputV: Vec<char> = s.chars().collect();
let mut charSet = HashSet::new();
for c in s.chars() {
if charSet.contains(&c) {
loop {
let nextC = inputV[startIndex];
if (nextC == c) {
startIndex += 1;
break;
} else {
startIndex += 1;
charSet.remove(&nextC);
}
}
} else {
charSet.insert(c);
best = std::cmp::max(best, charSet.len());
}
}
return best as i32;
}
Conclusion
With today’s problem, we’ve covered another important problem-solving concept: the sliding window. We saw how this approach could work even with a fold in Haskell, considering one character at a time. We also saw how nested loops compare across Haskell and Rust.
For more problem solving tips and tricks, take a look at Solve.hs, our complete course on problem solving, data structures, and algorithms in Haskell. You’ll get tons of practice on problems like these so you can significantly level up your skills!
Two Pointer Algorithms
We’re now on to part 5 of our series comparing Haskell and Rust solutions for LeetCode problems. You can also look at the previous parts (Part 1, Part 2, Part 3, Part 4) to get some more context on what we’ve learned so far comparing these two languages.
For a full look at problem solving in Haskell, check out Solve.hs, our latest course! You’ll get full breakdowns on the processes for solving problems in Haskell, from basic list and loop problems to advanced algorithms!
The Problem
Today we’ll be looking at a problem called Trapping Rain Water. In this problem, we’re given a vector of heights, which form a sort of 1-dimensional topology. Our job is to figure out how many units of water could be collected within the topology.
As a very simple example, the input [1,0,2]
could collect 1 unit of water. Here’s a visualization of that system, where x
shows the topology and o
shows water we collect:
  x
xox
We can never collect any water over the left or right “edges” of the array, since it would flow off. The middle index of our array though is lower than its neighbors. So we take the lower of these neighboring values, and we see that we can collect 1
unit of water in this system.
For a bigger example that collects water, we might have the input [4, 2, 1, 1, 3, 5]
. Here’s what that looks like:
          x
x o o o o x
x o o o x x
x x o o x x
x x x x x x
The total water here is 9.
A flat system like [2,2,2]
, or a system that looks like a peak [1,2,3,2,1]
cannot collect any water, so we should return 0 in these cases.
The Algorithm
There are a couple ways to solve this. One approach would be a two-pass solution, similar to what we used in Product of Array Except Self. We loop from the left side, tracking the maximum water we can store in each unit based on its left neighbors. Then we loop again from the right side and compare the maximum we can store based on the right neighbors to the prior value from the left. This solution is O(n)
time, but O(n)
space as well.
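To make that two-pass idea concrete, here's a short Haskell sketch of it (waterNaive is an illustrative name; it isn't the solution we'll build below):
-- For each index, the water above it is bounded by the smaller of the running
-- maximum from the left and the running maximum from the right.
waterNaive :: [Int] -> Int
waterNaive heights = sum
  [ max 0 (min leftMax rightMax - h)
  | (h, leftMax, rightMax) <- zip3 heights (scanl1 max heights) (scanr1 max heights)
  ]
-- waterNaive [1,0,2]       == 1
-- waterNaive [4,2,1,1,3,5] == 9
-- waterNaive [2,2,2]       == 0
Keeping both scans around is where the extra O(n) space comes from.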
A more optimal solution for this problem is a two-pointer approach that can use O(1)
additional space. In this kind of solution, we look at the left and right of the input simultaneously. Each step of the way, we make a decision to either increase the “left pointer” or decrease the “right pointer” until they meet in the middle. Each time we move, we get more information about our solution.
In this particular problem, we’ll track the maximum value we’ve seen from the left side and the maximum value we’ve seen from the right side. As we traverse each index, we update both sides for the current left and right indices if we have a new maximum.
The crucial step is to see that if the current “left max” is no larger than the current “right max”, we know how much water can be stored at the left index. This is just the left max minus the height at the left index. Then we can increment the left index.
If the opposite is true, we calculate how much water can be stored at the right index, and decrease the right index.
So we keep a running tally of these sums, and we end our loop when they meet in the middle.
Rust Solution
We can describe our algorithm as a simple while loop. This loop goes until the left index exceeds the right index. The loop needs to track 5 values:
- Left Index
- Right Index
- Left Max
- Right Max
- Total sum so far
So let’s write the setup portion of the loop:
pub fn trap(height: Vec<i32>) -> i32 {
let mut leftMax = -1;
let mut rightMax = -1;
let mut leftI = 0;
let mut rightI = height.len() - 1;
let mut total = 0;
while leftI <= rightI {
...
}
}
A subtle thing: the constraints on the LeetCode problem guarantee that the length is at least 1. But to handle length-0 cases, we would need a special case. Rust uses unsigned integers for vector lengths, so taking height.len() - 1 on a length-0 vector would underflow (panicking in debug builds, or wrapping to a huge number in release builds), and this would mess up our loop and indexing.
Within the while
loop, we run the algorithm.
- Adjust leftMax and rightMax if necessary.
- If leftMax is not larger, add leftMax - height[leftI] to total and increment leftI.
- Otherwise (rightMax is smaller), add rightMax - height[rightI] to total and decrement rightI.
And at the end, we return our total!
pub fn trap(height: Vec<i32>) -> i32 {
let n = height.len();
if n <= 1 {
return 0;
}
let mut leftMax = -1;
let mut rightMax = -1;
let mut leftI = 0;
let mut rightI = n - 1;
let mut total = 0;
while leftI <= rightI {
// Step 1
leftMax = std::cmp::max(leftMax, height[leftI]);
rightMax = std::cmp::max(rightMax, height[rightI]);
if leftMax <= rightMax {
// Step 2
total += leftMax - height[leftI];
leftI += 1;
} else {
// Step 3
total += rightMax - height[rightI];
rightI -= 1;
}
}
return total;
}
Haskell Solution
Now that we’ve seen our Rust solution with a single loop, let’s remember our process for translating this idea to Haskell. With a two-pointer loop, the way in which we traverse the elements of the input is unpredictable, thus we need a raw recursive function, rather than a fold or a map.
Since we’re tracking 5 integer values, we’ll want to write a loop
function that looks like this:
-- (leftIndex, rightIndex, leftMax, rightMax, sum)
loop :: (Int, Int, Int, Int, Int) -> Int
Knowing this, we can already “start from the end” and figure out how to invoke our loop from the start of our function:
trapWater :: V.Vector Int -> Int
trapWater input = loop (0, n - 1, -1, -1, 0)
where
n = V.length input
loop :: (Int, Int, Int, Int, Int) -> Int
loop = undefined
In writing our recursive loop, we’ll start with the base case. Once leftI
is the bigger index, we return the total.
trapWater :: V.Vector Int -> Int
trapWater input = loop (0, n - 1, -1, -1, 0)
where
n = V.length input
loop :: (Int, Int, Int, Int, Int) -> Int
loop (leftI, rightI, leftMax, rightMax, total) = if leftI > rightI then total
else ...
Within the else
case, we just follow our algorithm, with the same 3 steps we saw with Rust.
trapWater :: V.Vector Int -> Int
trapWater input = loop (0, n - 1, -1, -1, 0)
where
n = V.length input
-- (leftIndex, rightIndex, leftMax, rightMax, sum)
loop :: (Int, Int, Int, Int, Int) -> Int
loop (leftI, rightI, leftMax, rightMax, total) = if leftI > rightI then total
else
-- Step 1
let leftMax' = max leftMax (input V.! leftI)
rightMax' = max rightMax (input V.! rightI)
in if leftMax' <= rightMax'
-- Step 2
then loop (leftI + 1, rightI, leftMax', rightMax', total + leftMax' - input V.! leftI)
-- Step 3
else loop (leftI, rightI - 1, leftMax', rightMax', total + rightMax' - input V.! rightI)
And we have our Haskell solution!
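As a quick check against the examples from earlier (assuming Data.Vector is imported as V and the definition above is in scope; quickChecks is just an illustrative name):
quickChecks :: [Int]
quickChecks =
  [ trapWater (V.fromList [1, 0, 2])          -- 1
  , trapWater (V.fromList [4, 2, 1, 1, 3, 5]) -- 9
  , trapWater (V.fromList [2, 2, 2])          -- 0
  ]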
Conclusion
If you’ve been following this whole series so far, hopefully you’re starting to get a feel for comparing basic algorithms in Haskell and Rust (standing as a proxy for most loop-based languages). In general, we can write loops as recursive functions in Haskell, capturing the “state” of the list as the input parameter for that function.
In particular cases where each iteration deals with exactly one element of an input list, we can employ folds as a tool to simplify our functions. But the two-pointer algorithm we explored today falls into the general recursive category.
To learn the details of understanding these problem solving techniques, take a look at our course, Solve.hs! You’ll learn everything from basic loop and list techniques, to advanced data structures and algorithms!
Spatial Reasoning with Zigzag Patterns!
Today we’re continuing our study of Rust and Haskell solutions to basic coding problems. This algorithm is going to be a little harder than the last few we’ve done in this series, and it will get trickier from here!
For a complete study of problem solving techniques in Haskell, make sure to check out Solve.hs. This course runs the gamut from basic solving techniques to advanced data structures and algorithms, so you’ll learn a lot!
The Problem
Today’s problem is Zigzag Conversion. This is an odd problem that stretches your ability to think iteratively and spatially. The idea is that you’re given an input string and a number of “rows”. You need to then imagine the input word written as a zig-zag pattern, where you write the letters in order first going down, and then diagonally up to the right until you get back to the first row. Then it goes down again. Your output must be characters re-ordered in “row-order” after this zig-zag rearrangement.
This makes the most sense looking at examples. Let’s go through several variations with the string MONDAYMORNINGHASKELL
. Here’s what it looks like with 3 rows.
M   A   R   G   K
O D Y O N N H S E L
N   M   I   A   L
So to get the answer, we read along the top line first (MARGK
), then the second (ODYONNHSEL
), and then the third (NMIAL
). So the final answer is MARGKODYONNHSELNMIAL
.
Now let’s look at the same string in 4 rows:
M     M     G     L
O   Y O   N H   E L
N A   R I   A K
D     N     S
The answer here is MMGLOYONHELNARIAKDNS
.
Here’s 5 rows:
M       R       K
O     O N     S E
N   M   I   A   L
D Y     N H     L
A       G
The answer here is MRKOONSENMIALDYNHLAG
.
And now that we have the pattern, we can also consider 2 rows, which doesn’t visually look like a zig-zag as much:
M N A M R I G A K L
O D Y O N N H S E L
This gives the answer MNAMRIGAKLODYONNHSEL
.
Finally, if there’s only 1 row, you can simply return the original string.
The Algorithm
So how do we go about solving this? The algorithm here is a bit more involved than the last few weeks!
Our output order is row-by-row, so for our solution we should think in a row-by-row fashion. If we can devise a function that will determine the indices of the original string that belong in each row, then we can simply loop over the rows and append these results!
In order to create this function, we have to think about the zig-zag in terms of “cycles”. Each cycle begins at the top row, goes down to the bottom row, and then up diagonally to the second row. The next element to go at the top row starts a new cycle. By thinking about cycles, we’ll discover a few key facts:
- With n rows (n >= 2), a complete cycle has 2n - 2 letters.
- The top and bottom row get one letter per cycle.
- All other rows get two letters per cycle.
Now we can start to think mathematically about the indices that belong in each row. It’s easiest to think about the top and bottom rows, since they only get one letter each cycle. Each of these has a starting index (0
and n - 1
, respectively), and then we add the cycle length 2n - 2
to these starting indices until it exceeds the length.
The middle rows have this same pattern, only now they have 2 starting indices. They have the starting index from the “down” direction and then their first index going up and to the right. For the row with index i, the first (“down”) index is obviously i itself, but the second index is harder to see.
The easiest way to find the second index is backwards! The next cycle starts at 2n - 2
. So row index 1
has its second index at 2n - 2 - 1
, and row index 2
has its second index at 2n - 2 - 2
, and so on! The pattern of adding the “cycle number” will work for all starting indices.
Once we have the indices for each row, our task is simple. We build a string for each row and combine them together in order.
So suppose we have our 4-row example.
M     M     G     L
O   Y O   N H   E L
N A   R I   A K
D     N     S
The “cycle num” is 6 (2 * 4 - 2). So the first row has indices [0, 6, 12, 18]
. The fourth row starts with index 3, and so its indices also go up by 6 each time: [3, 9, 15]
.
The second row (index 1) has starting indices 1 and 5 (6 - 1
). So its indices are [1, 5, 7, 11, 13, 17, 19]
. Then the third row has indices [2, 4, 8, 10, 14, 16]
.
A vector input will allow us to efficiently use and combine these indices.
As a final note, the “cycle num” logic doesn’t end up working with only 1 row. The cycle length using our calculation would be 0, not 1 as it should. The discrepancy is because our “cycle num” logic really depends on having a “first” and “last” row. So if we only have 1 row, we’ll hardcode that case and return the input string.
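Before jumping into the implementations, here's a small Haskell sketch of this index math. The helper rowIndices is illustrative only (and assumes numRows >= 2, since we handle 1 row separately):
import Data.List (sort)

-- Indices of an input of length n that land in the given row (0-indexed),
-- using the cycle math described above.
rowIndices :: Int -> Int -> Int -> [Int]
rowIndices n numRows row =
  let cycleLen = 2 * numRows - 2
      starts = if row == 0 || row == numRows - 1
                 then [row]
                 else [row, cycleLen - row]
  in sort (concatMap (\s -> [s, s + cycleLen .. n - 1]) starts)
-- For the 4-row example (n = 20):
-- rowIndices 20 4 0 == [0,6,12,18]
-- rowIndices 20 4 1 == [1,5,7,11,13,17,19]
-- rowIndices 20 4 2 == [2,4,8,10,14,16]
-- rowIndices 20 4 3 == [3,9,15]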
Rust Solution
In our rust solution, we’ll accumulate our result string in place. To accomplish this we’ll do a few setup steps:
- Handle our base case (1 row)
- Get the string length and cycle number
- Make a vector of the input chars for easy indexing (Rust doesn’t allow string indexing)
- Initialize our mutable result string
pub fn convert(s: String, num_rows: i32) -> String {
if (num_rows == 1) {
return s;
}
let n = s.len();
let nr = num_rows as usize; // Convenience for comparison
let cycleLen: usize = (2 * nr - 2);
let sChars: Vec<char> = s.chars().collect();
let mut result = String::new();
...
}
Now we have to add the rows in order. Since the logic differs for the first and last rows, we have 3 sections: first row, middle rows, and last row. The first and last row are straightforward using our algorithm. Each is a simple while loop.
pub fn convert(s: String, num_rows: i32) -> String {
if (num_rows == 1) {
return s;
}
let n = s.len();
let nr = num_rows as usize; // Convenience for comparison
let cycleLen: usize = (2 * nr - 2);
let sChars: Vec<char> = s.chars().collect();
let mut result = String::new();
// First Row
let mut i = 0;
while i < n {
result.push(sChars[i]);
i += cycleLen;
}
// Middle Rows
...
// Last Row
i = (nr - 1);
while i < n {
result.push(sChars[i]);
i += cycleLen;
}
return result;
}
Now the middle rows section is similar. We loop through each of the possible rows in the middle. For each of these, we’ll do a while loop similar to the first and last row. These loops are different though, because we have to track two possible values, the “first” and “second” of each cycle.
If the “first” is already past the end of the vector, then we’re already done and can skip the loop. But even if not, we still need an “if check” on the “second” value as well. Each time through the loop, we increase both values by cycleLen
.
pub fn convert(s: String, num_rows: i32) -> String {
if (num_rows == 1) {
return s;
}
let n = s.len();
let nr = num_rows as usize; // Convenience for comparison
let cycleLen: usize = (2 * nr - 2);
let sChars: Vec<char> = s.chars().collect();
let mut result = String::new();
// First Row
let mut i = 0;
while i < n {
result.push(sChars[i]);
i += cycleLen;
}
// Middle Rows
for row in 1..(nr - 1) {
let mut first = row;
let mut second = cycleLen - row;
while first < n {
result.push(sChars[first]);
if second < n {
result.push(sChars[second]);
}
first += cycleLen;
second += cycleLen;
}
}
// Last Row
i = (nr - 1);
while i < n {
result.push(sChars[i]);
i += cycleLen;
}
return result;
}
And that’s our complete solution!
Haskell Solution
The Haskell solution follows the same algorithm, but we’ll make a few stylistic changes compared to Rust. In Haskell, we’ll go ahead and define specific lists of indices for each row. That way, we can combine these lists and make our final string all at once using concatMap
. This approach will let us demonstrate the power of ranges in Haskell.
We start our defining our base case and core parameters:
zigzagConversion :: String -> Int -> String
zigzagConversion input numRows = if numRows == 1 then input
  else ...
  where
    n = length input
    cycleLen = 2 * numRows - 2
    ...
Now we can define index-lists for the first and last rows. These are just ranges! We have the starting element, and we know to increment it by cycleLen
. The range should go no higher than n - 1
. Funny enough, the range can figure out that it should be empty in the edge case that our input is too small to fill all the rows!
zigzagConversion :: String -> Int -> String
zigzagConversion input numRows = if numRows == 1 then input
  else ...
  where
    n = length input
    cycleLen = 2 * numRows - 2
    firstRow :: [Int]
    firstRow = [0,cycleLen..n - 1]
    lastRow :: [Int]
    lastRow = [numRows - 1, numRows - 1 + cycleLen..n - 1]
    ...
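To see that empty-range behavior concretely, here’s a small worked example of my own (not an input from the problem itself), evaluating these ranges for a 2-character input with 3 rows:
-- Hypothetical example: input "AB", numRows = 3
-- n = 2, cycleLen = 2 * 3 - 2 = 4
firstRow = [0,4..1]  -- [0]
lastRow = [2,6..1]   -- [] (empty: the starting index is already past n - 1)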
In Rust, we used a while loop with two state values to calculate the middle rows. Hopefully you know from this series by now that such a while loop translates into a recursive function in Haskell. We’ll accumulate our list of indices in an accumulator argument, and keep the two stateful values as our other input parameters. We’ll combine all our lists together into one big list of int-lists, allRows
.
zigzagConversion :: String -> Int -> String
zigzagConversion input numRows = if numRows == 1 then input
  else ...
  where
    n = length input
    cycleLen = 2 * numRows - 2
    firstRow :: [Int]
    firstRow = [0,cycleLen..n - 1]
    lastRow :: [Int]
    lastRow = [numRows - 1, numRows - 1 + cycleLen..n - 1]
    middleRow :: Int -> Int -> [Int] -> [Int]
    middleRow first second acc = if first >= n then reverse acc
      else if second >= n then reverse (first : acc)
      else middleRow (first + cycleLen) (second + cycleLen) (second : first : acc)
    middleRows :: [[Int]]
    middleRows = map (\i -> middleRow i (cycleLen - i) []) [1..numRows-2]
    allRows :: [[Int]]
    allRows = firstRow : middleRows <> [lastRow]
    ...
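To convince ourselves that middleRow produces the right indices, here’s a short trace of my own (assuming a 14-character input with 4 rows, so cycleLen is 6) for the first middle row:
-- Hypothetical trace: n = 14, cycleLen = 6
middleRow 1 5 []
-- first = 1,  second = 5:  both in range, acc becomes [5,1]
-- first = 7,  second = 11: both in range, acc becomes [11,7,5,1]
-- first = 13, second = 17: second is past the end, so we keep only first
-- result: reverse (13 : [11,7,5,1]) = [1,5,7,11,13]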
Now we bring it all together with one final step. We make a vector from our input, and define a function to turn a single int-list into a single String. Then at the top level of our function (the original else
branch), we use concatMap
to bring these together into our final result String.
import qualified Data.Vector as V

zigzagConversion :: String -> Int -> String
zigzagConversion input numRows = if numRows == 1 then input
  else concatMap rowIndicesToString allRows
  where
    n = length input
    cycleLen = 2 * numRows - 2
    firstRow :: [Int]
    firstRow = [0,cycleLen..n - 1]
    lastRow :: [Int]
    lastRow = [numRows - 1, numRows - 1 + cycleLen..n - 1]
    middleRow :: Int -> Int -> [Int] -> [Int]
    middleRow first second acc = if first >= n then reverse acc
      else if second >= n then reverse (first : acc)
      else middleRow (first + cycleLen) (second + cycleLen) (second : first : acc)
    middleRows :: [[Int]]
    middleRows = map (\i -> middleRow i (cycleLen - i) []) [1..numRows-2]
    allRows :: [[Int]]
    allRows = firstRow : middleRows <> [lastRow]
    inputV :: V.Vector Char
    inputV = V.fromList input
    rowIndicesToString :: [Int] -> String
    rowIndicesToString = map (inputV V.!)
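As a quick sanity check in GHCi, using the standard LeetCode examples for this problem (the expected strings below come from those examples):
zigzagConversion "PAYPALISHIRING" 3  -- "PAHNAPLSIIGYIR"
zigzagConversion "PAYPALISHIRING" 4  -- "PINALSIGYAHRPI"
zigzagConversion "A" 1               -- "A" (our hardcoded base case)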
Conclusion
This comparison once again showed how while loops in Rust track with recursive functions in Haskell. We also saw some nifty Haskell features like ranges and tail recursion. Most of all, we saw that even with a trickier algorithm, we can keep the same basic shape of the solution whether we write it in a functional or an imperative style.
To learn more about these problem solving concepts, take a look at Solve.hs, our comprehensive course on problem solving in Haskell. You’ll learn about recursion, list manipulation, data structures, graph algorithms, and so much more!
Starting from the End: Solving “Product Except Self”
Today we continue our series exploring LeetCode problems and comparing Haskell and Rust solutions. We’re staying in the realm of list/vector manipulation, but the problems are going to start getting more challenging!
If you want to learn more about problem solving in Haskell, you should take a closer look at Solve.hs! In particular, you’ll learn how to translate common loop-based ideas into Haskell’s recursive style!
The Problem
Today’s problem is Product of Array Except Self. The idea is that we are given a vector of n integers. We are supposed to return another vector of n integers, where output[i]
is equivalent to the product of all the input integers except for input[i]
.
The key constraint here is that we are not allowed to use division. If we could use division, the answer would be simple! We would find the product of the input numbers and then divide this product by each input number to find the corresponding value. But division is more expensive than most other numeric operations, so we want to avoid it if possible!
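Just to make the contrast concrete, here’s a minimal Haskell sketch of that disallowed division approach (my own illustration, not part of the solutions below). Besides the cost of division, notice that it also breaks down if the input contains a zero:
-- Hypothetical sketch of the *disallowed* division approach
productViaDivision :: [Int] -> [Int]
productViaDivision xs = map (total `div`) xs
  where
    total = product xs

-- productViaDivision [3, 4, 5] == [20, 15, 12]
-- ...but it divides by zero if any element is 0.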
The Algorithm
The approach we’ll use in this article relies on “prefix products” and “suffix products”. We’ll make two separate vectors called prefixes
and suffixes
, where prefixes[i]
is the product of all numbers strictly before index i
, and suffixes[i]
is the product of all numbers strictly after index i
.
Then, we can easily produce our results. The value output[i]
is simply the product of prefixes[i]
and suffixes[i]
.
As an example, our input might be [3, 4, 5]
. The prefixes
vector should be [1, 3, 12]
, and the suffixes
vector should be [20, 5, 1]
. Then our final output should be [20, 15, 12]
.
prefixes: [1, 3, 12]
suffixes: [20, 5, 1]
output: [20, 15, 12]
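As a quick sanity check on these numbers, here’s a minimal Haskell sketch of my own using scanl and scanr (a different formulation from the tail-recursive solution we build below) that reproduces them in GHCi:
xs = [3, 4, 5]
prefixes = init (scanl (*) 1 xs)         -- [1, 3, 12]
suffixes = tail (scanr (*) 1 xs)         -- [20, 5, 1]
output = zipWith (*) prefixes suffixes   -- [20, 15, 12]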
Rust Solution
Here’s our Rust solution:
impl Solution {
    pub fn product_except_self(nums: Vec<i32>) -> Vec<i32> {
        let n = nums.len();
        let mut prefixes = vec![0; n];
        let mut suffixes = vec![0; n];
        let mut totalPrefix = 1;
        let mut totalSuffix = 1;
        // Loop 1: Populate prefixes & suffixes
        for i in 0..n {
            prefixes[i] = totalPrefix;
            totalPrefix *= nums[i];
            suffixes[n - i - 1] = totalSuffix;
            totalSuffix *= nums[n - i - 1];
        }
        let mut results = vec![0; n];
        // Loop 2: Populate results
        for i in 0..n {
            results[i] = prefixes[i] * suffixes[i];
        }
        return results;
    }
}
The two for-loops provide this solution with its shape. The first loop generates our vectors prefixes
and suffixes
. We keep track of a running tally of the totalPrefix
and the totalSuffix
. Each of these is initially 1.
let n = nums.len();
let mut prefixes = vec![0; n];
let mut suffixes = vec![0; n];
let mut totalPrefix = 1;
let mut totalSuffix = 1;
On each iteration, we assign the current “total prefix” to the prefixes
vector in the front index i
, and then the “total suffix” to the suffixes
vector in the back index n - i - 1
. Then we multiply each total value by the input value (nums
) from that index so it’s ready for the next iteration.
// Loop 1: Populate prefixes & suffixes
for i in 0..n {
    prefixes[i] = totalPrefix;
    totalPrefix *= nums[i];
    suffixes[n - i - 1] = totalSuffix;
    totalSuffix *= nums[n - i - 1];
}
And now we calculate the result, by taking the product of prefixes
and suffixes
at each index.
let mut results = vec![0; n];
// Loop 2: Populate results
for i in 0..n {
    results[i] = prefixes[i] * suffixes[i];
}
return results;
Haskell Solution
In Haskell, we can follow this same template. However, a couple differences stand out. First, we don’t use for-loops. We have to use recursion or recursive helpers to accomplish these loops. Second, when constructing prefixes
and suffixes
, we want to use lists instead of modifying mutable vectors.
When performing recursion and accumulating linked lists, it can be tricky to reason about which lists need to be reversed at which points in our algorithm. For this reason, it’s often very helpful in Haskell to start from the end of our algorithm.
Let’s write out a template of our solution that leaves prefixes
and suffixes
as undefined stubs. Then the first step we’ll work through is how to get the solution from that:
productOfArrayExceptSelf :: V.Vector Int -> V.Vector Int
productOfArrayExceptSelf inputs = solution ???
  where
    n = V.length inputs
    solution :: ??? -> V.Vector Int
    prefixes :: [Int]
    prefixes = undefined
    suffixes :: [Int]
    suffixes = undefined
So given prefixes
and suffixes
, how do we find our solution? The ideal case is that both these lists are already in reverse-index order with respect to the input vector (i.e. n - 1
to 0
). Then we don’t need to do an additional reverse to get our solution.
We can then implement solution
as a simple tail recursive helper function that peels one element off each input and multiplies them together. When we’re out of inputs, it returns its result:
productOfArrayExceptSelf :: V.Vector Int -> V.Vector Int
productOfArrayExceptSelf inputs = solution (prefixes, suffixes, [])
  where
    n = V.length inputs
    -- Loop 2: Populate Results
    solution :: ([Int], [Int], [Int]) -> V.Vector Int
    solution ([], [], acc) = V.fromList acc
    solution (p : ps, s : ss, acc) = solution (ps, ss, p * s : acc)
    solution _ = error "Prefixes and suffixes must be the same size!"
    prefixes :: [Int]
    suffixes :: [Int]
So now we’ve done “Loop 2” already, and we just have to implement “Loop 1” so that it produces the right results. Again, we’ll make a tail recursive helper, and this will produce both prefixes
and suffixes
at once. It will take the index, as well as the “total” prefix and suffix so far, and then two accumulator lists. At the end of this, we want both lists in reverse index order.
productOfArrayExceptSelf :: V.Vector Int -> V.Vector Int
productOfArrayExceptSelf inputs = solution (prefixes, suffixes, [])
  where
    n = V.length inputs
    -- Loop 2: Populate Results
    solution :: ([Int], [Int], [Int]) -> V.Vector Int
    prefixes :: [Int]
    suffixes :: [Int]
    (prefixes, suffixes) = mkPrefixSuffix (0, 1, [], 1, [])
    -- Loop 1: Populate prefixes & suffixes
    mkPrefixSuffix :: (Int, Int, [Int], Int, [Int]) -> ([Int], [Int])
    mkPrefixSuffix (i, totalPre, pres, totalSuff, suffs) = undefined
Now we fill in mkPrefixSuffix
as we would any tail recursive helper. First we satisfy the base case. This occurs once i
is at least n
. We’ll return the accumulated lists.
mkPrefixSuffix :: (Int, Int, [Int], Int, [Int]) -> ([Int], [Int])
mkPrefixSuffix (i, totalPre, pres, totalSuff, suffs) = if i >= n then (pres, reverse suffs)
  else ...
But observe that we’ll need to reverse the suffixes
! This becomes clear when we map out what each iteration of the loop looks like for a simple input. Doing this kind of “loop tracking” is a very helpful problem solving skill for walking through your code!
input = [3, 4, 5]
i = 0: (0, 1, [], 1, [])
i = 1: (1, 3, [1], 5, [1])
i = 2: (2, 12, [3, 1], 20, [5, 1])
i = 3: (3, 60, [12, 3, 1], 60, [20, 5, 1])
Our accumulated prefixes are [12, 3, 1], which is already in reverse-index order (the value for index 2 comes first). But the accumulated suffixes are [20, 5, 1], which is in forward-index order: both lists end in 1, even though those 1s belong to opposite ends of the input. Since solution expects both lists in reverse-index order, we reverse the suffixes.
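We can double-check this with a quick GHCi calculation of my own, using zipWith in place of the solution helper:
zipWith (*) [12, 3, 1] [20, 5, 1]            -- [240, 15, 1] (wrong pairing)
zipWith (*) [12, 3, 1] (reverse [20, 5, 1])  -- [12, 15, 20] (values for indices 2, 1, 0)
-- Prepending each product, as 'solution' does, then yields [20, 15, 12].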
Now that we’ve figured this out, it’s simple enough to fill in the recursive case using what we already know from “Loop 1” in the Rust solution. We get the “front” index of inputs
with i
, and the “back” index with n - i - 1
, use these to get the new products, and then save the old products in our list.
mkPrefixSuffix :: (Int, Int, [Int], Int, [Int]) -> ([Int], [Int])
mkPrefixSuffix (i, totalPre, pres, totalSuff, suffs) = if i >= n then (pres, reverse suffs)
  else
    let nextPre = inputs V.! i
        nextSuff = inputs V.! (n - i - 1)
    in mkPrefixSuffix (i + 1, totalPre * nextPre, totalPre : pres, totalSuff * nextSuff, totalSuff : suffs)
Here’s our complete Haskell solution!
import qualified Data.Vector as V

productOfArrayExceptSelf :: V.Vector Int -> V.Vector Int
productOfArrayExceptSelf inputs = solution (prefixes, suffixes, [])
  where
    n = V.length inputs
    -- Loop 2: Populate Results
    solution :: ([Int], [Int], [Int]) -> V.Vector Int
    solution ([], [], acc) = V.fromList acc
    solution (p : ps, s : ss, acc) = solution (ps, ss, p * s : acc)
    solution _ = error "Prefixes and suffixes must be the same size!"
    prefixes :: [Int]
    suffixes :: [Int]
    (prefixes, suffixes) = mkPrefixSuffix (0, 1, [], 1, [])
    -- Loop 1: Populate prefixes & suffixes
    mkPrefixSuffix :: (Int, Int, [Int], Int, [Int]) -> ([Int], [Int])
    mkPrefixSuffix (i, totalPre, pres, totalSuff, suffs) = if i >= n then (pres, reverse suffs)
      else
        let nextPre = inputs V.! i
            nextSuff = inputs V.! (n - i - 1)
        in mkPrefixSuffix (i + 1, totalPre * nextPre, totalPre : pres, totalSuff * nextSuff, totalSuff : suffs)
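A couple of quick GHCi checks against our worked example (these inputs are my own, built with V.fromList just for illustration):
productOfArrayExceptSelf (V.fromList [3, 4, 5])     -- yields [20,15,12]
productOfArrayExceptSelf (V.fromList [1, 2, 3, 4])  -- yields [24,12,8,6]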
Conclusion
In this comparison, we saw a few important differences in problem solving with a loop-based language like Rust compared to Haskell.
- For-loops have to become recursion in Haskell
- We want to use lists in Haskell, not mutable vectors
- It takes a bit of planning to figure out when to reverse lists!
This led us to a couple important insights when solving problems in Haskell.
- “Starting from the end” can be very helpful in plotting out our solution
- “Loop tracking” is a very helpful skill to guide our solutions
For an in-depth look at these sorts of comparisons, check out our Solve.hs course. You’ll learn all the most important tips and tricks for solving coding problems in Haskell! In particular you’ll get an in-depth look at tail recursion, a vital concept for solving problems in Haskell.
Learning from Multiple Solution Approaches
Welcome to the second article in our Rust vs. Haskell problem solving series. Last week we saw some basic differences between Rust loops and Haskell recursion. We also saw how to use the concept of “folding” to simplify a recursive loop function.
This week, we’ll look at another simple problem and consider multiple solutions in each language. We’ll consider what a “basic” solution looks like, using relatively few library functions. Then we’ll consider more “advanced” solutions that make use of library functionality, and greatly simplify the structure of our solutions.
To learn more about problem solving in Haskell, including the importance of list library functions, take a look at our course Solve.hs! You’ll write most of Haskell’s list API from scratch so you get an in-depth understanding of the functions that are available!
The Problem
This week’s problem is Reverse Words in a String. The idea is simple. Our input is a string, which naturally has “words” separated by whitespace. We want to return a string that has all the words reversed! So if the input is ”A quick brown fox”
, the result should be ”fox brown quick A”
.
Notice that all whitespace is truncated in our output. We should only have a single space between words in our answer, with no leading or trailing whitespace.
The Algorithm
The algorithmic idea is simple and hardly needs explanation. We gather characters from the input until we encounter whitespace. Then we append this buffered word to a growing result string, and we keep following this process until we run out of input.
There is one wrinkle, which is whether we want to accumulate our answer in the forward or reverse direction. This changes across languages!
In Haskell, it’s actually more efficient to accumulate the “back” of our resulting string first, meaning we should start by iterating from the front of the input. This is more consistent with linked list construction.
In Rust, we’ll iterate from the back of the input so that we can accumulate our result from the “front”.
Basic Rust Solution
In our basic solution, we’re going to consider a character-by-character approach. As outlined in our algorithm, we can accomplish this task with a single loop, with two stateful values. First, we have the “current” word we’re accumulating of non-whitespace characters. Second, we have the final “result” we’re accumulating.
It’s efficient to append to the end of strings, meaning we want to construct our result from front-to-back. This means we’ll loop through the characters of our string in reverse, as shown with .rev()
here:
pub fn reverse_words(s: String) -> String {
    let mut current = String::new();
    let mut result = String::new();
    for c in s.chars().rev() {
        ...
    }
}
Within the loop, we now just have to consider what to do with each character. If the character is not whitespace, the answer is simple. We just append this character to our “current” word. Because we’re looping through the input in reverse, our “current” word will also be in reverse!
pub fn reverse_words(s: String) -> String {
    let mut current = String::new();
    let mut result = String::new();
    for c in s.chars().rev() {
        if !c.is_whitespace() {
            current.push(c);
        } else {
            ...
        }
    }
}
So what happens when we encounter whitespace? There are a few conditions to consider:
- If “current” is empty, do nothing.
- If “result” is empty, append “current” (in reverse order) to result.
- If “result” is not empty, add a space and then append “current” in reverse.
- Regardless, clear “current” and prepare to gather a new string.
Here’s what the code looks like:
pub fn reverse_words(s: String) -> String {
    let mut current = String::new();
    let mut result = String::new();
    for c in s.chars().rev() {
        if !c.is_whitespace() {
            current.push(c);
        } else {
            // Step 1: Skip if "current" is empty
            if !current.is_empty() {
                // Steps 2/3: Only push a space if "result" is not empty
                if !result.is_empty() {
                    result.push(' ');
                }
                // Steps 2/3: Reverse "current" and append
                for b in current.chars().rev() {
                    result.push(b);
                }
                // Step 4: Clear "current"
                current.clear();
            }
        }
    }
}
There’s one final trick. Unless the input begins with whitespace, we’ll still have a non-empty current
at the end and we will not have appended it. So we do one final check, and once again append “current” in reverse order.
Here’s our final basic solution:
pub fn reverse_words(s: String) -> String {
    let mut current = String::new();
    let mut result = String::new();
    for c in s.chars().rev() {
        if !c.is_whitespace() {
            current.push(c);
        } else {
            // Step 1: Skip if "current" is empty
            if !current.is_empty() {
                // Steps 2/3: Only push a space if "result" is not empty
                if !result.is_empty() {
                    result.push(' ');
                }
                // Steps 2/3: Reverse "current" and append
                for b in current.chars().rev() {
                    result.push(b);
                }
                // Step 4: Clear "current"
                current.clear();
            }
        }
    }
    if !current.is_empty() {
        if !result.is_empty() {
            result.push(' ');
        }
        for b in current.chars().rev() {
            result.push(b);
        }
    }
    return result;
}
Advanced Rust Solution
Looping character-by-character is a bit cumbersome. However, since basic whitespace-related operations are so common, there are some useful library functions for dealing with them.
Rust also prioritizes the ability to chain iterative operations together. This gives us the following one-line solution!
pub fn reverse_words(s: String) -> String {
    s.split_whitespace().rev().collect::<Vec<&str>>().join(" ")
}
It has four stages:
- Split the input based on whitespace.
- Reverse the split-up words.
- Collect these words as a vector of strings.
- Join them together with one space in between them.
What is interesting about this structure is that each stage of the process has a separate type. Step 1 creates a SplitWhitespace
struct. Step 2 creates a Rev
struct. Step 3 then creates a normal vector, and step 4 concludes by producing a string.
The two preliminary structures are essentially wrappers with iterators to help chain the operations together. As we’ll see, the comparable Haskell solution only uses basic lists, and this is a noteworthy difference between the languages.
Basic Haskell Solution
Our “basic” Haskell solution will follow the same outline as the basic Rust solution, but we’ll work in the opposite direction! We’ll loop through the input in forward order, and accumulate our output in reverse order.
Before we even get started though, we can observe from our basic Rust solution that we duplicated some code! The concept of combining the “current” word and the “result” had several edge cases to handle, so let’s write a combine
function to handle these.
-- “current” is reversed and then goes in *front* of result
-- (Rust version put “current” at the back)
combine :: (String, String) -> String
combine (current, res) = if null current then res
  else reverse current <> if null res then "" else (' ' : res)
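For example, folding over the input “A quick” reaches exactly these states, which we can check in GHCi:
combine ("A", "")       -- "A" (first word flushed into an empty result)
combine ("kciuq", "A")  -- "quick A" (reversed word goes in front, with a space)
combine ("", "quick A") -- "quick A" (an empty current leaves the result alone)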
Now let’s think about our loop structure. We are going through the input, character-by-character. This means we should be able to use a fold, like we did last week! Whenever we’re using a fold, we want to think about the “state” we’re passing through each iteration. In our case, the state is the “current” word and the “result” string. This means our folding function should look like this:
loop :: (String, String) -> Char -> (String, String)
loop (current, result) c = ...
Now we just have to distinguish between the “whitespace” case and the non-whitespace case. If we encounter a space, we just combine the current word with the accumulated result. If we encounter a normal character, we append this to our current word (again, accumulating “current” in reverse).
loop :: (String, String) -> Char -> (String, String)
loop (currentWord, result) c = if isSpace c
  then ("", combine (currentWord, result))
  else (c : currentWord, result)
To complete the solution, we call foldl with our loop function, a starting state of two empty strings, and the input. Then we just have to remember to combine the final “current” word with the output! Here’s our complete “basic” solution.
import Data.Char (isSpace)

reverseWords :: String -> String
reverseWords input = combine $ foldl loop ("", "") input
  where
    combine :: (String, String) -> String
    combine (current, res) = if null current then res
      else reverse current <> if null res then "" else (' ' : res)
    loop :: (String, String) -> Char -> (String, String)
    loop (currentWord, result) c = if isSpace c
      then ("", combine (currentWord, result))
      else (c : currentWord, result)
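And a quick GHCi check of my own, with some extra whitespace thrown in to exercise the edge cases:
reverseWords "  A quick   brown fox  "  -- "fox brown quick A"
reverseWords "hello"                    -- "hello"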
Advanced Haskell Solutions
Now that we’ve seen a basic, character-by-character solution in Haskell, we can also consider more advanced solutions that incorporate library functions. The first improvement we can make is to lean on list functions like break
and dropWhile
.
Using break
splits off the first part of a list that does not satisfy a predicate. We’ll use this to gather non-space characters. Then dropWhile
allows us to drop the first series of characters in a list that satisfy a predicate. We’ll use this to get rid of whitespace as we move along!
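If you haven’t seen these before, here are a couple of GHCi examples of my own showing how they behave:
break isSpace "fox jumps"     -- ("fox", " jumps")
dropWhile isSpace "   jumps"  -- "jumps"
break isSpace ""              -- ("", "")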
So we’ll define this solution using a basic recursive loop rather than a fold, because each iteration will consume a variable number of characters. The “state” of this loop will be two strings: the remaining part of the input, and the accumulated result.
Since there’s no “current” word, our base case is easy. If the remaining input is empty, we return the accumulated result.
loop :: (String, String) -> String
loop ([], output) = output
...
Otherwise, we’ll follow this process:
- Separate the first “word” using break isSpace.
- Combine this word with the output (if it’s not null)
- Recurse with the new output, dropping the initial whitespace from the remainder.
Here’s what it looks like:
loop :: (String, String) -> String
loop ([], output) = output
loop (cs, output) =
  -- Step 1: Separate next word from rest
  let (nextWord, rest) = L.break isSpace cs
      -- Step 2: Make new output (account for edge cases)
      -- (Can’t use ‘combine’ from above because we aren’t reversing!)
      newOutput = if null output then nextWord
        else if null nextWord then output
        else nextWord <> (' ' : output)
  -- Drop spaces from remainder and recurse
  in loop (L.dropWhile isSpace rest, newOutput)
And completing the function is as simple as calling this loop with the base inputs:
reverseWords :: String -> String
reverseWords input = loop (input, "")
The Simplest Haskell Solution
The final (and recommended) Haskell solution uses the library functions words
and unwords
. These do exactly what we want for this problem! We separate words based on whitespace using words
, and then join them with a single space with unwords
. All we have to do in between is reverse
.
reverseWords :: String -> String
reverseWords = unwords . reverse . words
This has a similar elegance to the advanced Rust solution, but is much simpler to understand since there are no complex structs or iterators involved. The types of all functions involved simply relate to lists. Here are the signatures, specialized to String
for this problem.
words :: String -> [String]
reverse :: [String] -> [String]
unwords :: [String] -> String
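Seeing the pipeline in GHCi (with an example input of my own) makes the data flow clear:
words "  A quick   brown fox  "                -- ["A","quick","brown","fox"]
unwords (reverse ["A","quick","brown","fox"])  -- "fox brown quick A"
reverseWords "  A quick   brown fox  "         -- "fox brown quick A"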
Conclusion
A simple problem will often have many solutions, but in this case, each of these solutions teaches us something new about the language we’re working with. Working character-by-character helps us understand some of the core mechanics of the language, showing us how it works under the hood. But using library functions helps us see the breadth of available options we have for simplifying future code we write.
In our Solve.hs course, you’ll go through all of these steps with Haskell. You’ll implement list library functions, data structures, and algorithms from scratch so you understand how they work under the hood. Then, you’ll know they exist and be able to apply them to efficiently solve harder problems. Take a look at the course today!