In-Order Traversal in Haskell and Rust
Last time, we started exploring binary trees. We began with a simple problem (inverting a tree), but ran into some of the difficulties of implementing a recursive data structure in Rust.
Today we’ll do a slightly harder problem (LeetCode rates it “Medium” instead of “Easy”). This problem also works specifically with a binary search tree instead of a simple binary tree. In a search tree, the “values” on the nodes are orderable: all the values in the “left” subtree of any given node are no greater than that node’s value, and all the values in the “right” subtree are no smaller.
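Here is a minimal sketch of that ordering property. It assumes a simplified, Box-based Node type rather than LeetCode’s Rc<RefCell<TreeNode>>, and is_valid_bst and within are hypothetical helpers. The check passes down the range that each subtree’s values must fall within. Because the property allows equal values on either side, both bounds are inclusive.

```rust
/// Simplified, hypothetical tree node with owned children.
struct Node {
    val: i32,
    left: Option<Box<Node>>,
    right: Option<Box<Node>>,
}

/// Convenience constructor for building small trees.
fn node(val: i32, left: Option<Box<Node>>, right: Option<Box<Node>>) -> Option<Box<Node>> {
    Some(Box::new(Node { val, left, right }))
}

/// True if every value in `tree` lies in [lo, hi] and the ordering
/// holds recursively for every subtree.
fn within(tree: &Option<Box<Node>>, lo: i32, hi: i32) -> bool {
    match tree {
        None => true,
        Some(n) => {
            lo <= n.val
                && n.val <= hi
                // Left subtree: values no greater than this node's value.
                && within(&n.left, lo, n.val)
                // Right subtree: values no smaller than this node's value.
                && within(&n.right, n.val, hi)
        }
    }
}

fn is_valid_bst(tree: &Option<Box<Node>>) -> bool {
    within(tree, i32::MIN, i32::MAX)
}
```

For example, a tree whose root is 45 but which has 50 somewhere in its left subtree fails the check, even if every parent/child pair looks locally fine.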
Binary search trees, or close relatives like B-trees, are at the heart of most ordered Set types. In our problem-solving course Solve.hs, you’ll get the chance to build a self-balancing binary search tree from scratch, which involves some really cool algorithmic tricks!
The Problem
Though it’s harder than our previous problem, today’s problem is still straightforward. We are taking an ordered binary search tree and finding the k-th smallest element in that tree, where k is the second input to our function.
So suppose our tree looks like this:
```
        45
       /  \
     32    50
    /  \     \
   5    40    100
       /  \
     37    43
```
Our input k is 1-indexed. So if we get 1 as our input, we should return 5, the smallest element in the tree. If we receive 4, we should return 40, the 4th smallest element after 5, 32, and 37. If we get 8, we’ll return 100, the largest element in the tree.
The Algorithm
Balanced binary search trees are designed to give logarithmic-time access, insertion, and deletion. If our tree were annotated so that each node stored the size of its subtree (the number of nodes below it, plus itself), we could solve this problem in time proportional to the tree’s height, which is also logarithmic when the tree is balanced.
However, we want to assume a minimal tree design, where each node only holds its own value and pointers to its two children. With these constraints, our algorithm will take time linear in the input k, plus the time to walk down to the smallest element (proportional to the tree’s height).
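For comparison, here is a sketch of the annotated approach. It assumes a hypothetical SizedNode type whose size field counts every node in its subtree, including itself. The lookup follows a single path from the root, so its cost is proportional to the tree’s height. That cost is logarithmic only when the tree is balanced. The sketch does not show how to keep size up to date during insertions and deletions.

```rust
/// Hypothetical BST node annotated with the size of its subtree.
struct SizedNode {
    val: i32,
    size: usize, // nodes in this subtree, counting this one
    left: Option<Box<SizedNode>>,
    right: Option<Box<SizedNode>>,
}

/// Size of an optional subtree (0 when empty).
fn size(tree: &Option<Box<SizedNode>>) -> usize {
    tree.as_ref().map_or(0, |n| n.size)
}

/// Builds a node, computing its size from its children.
fn sized(
    val: i32,
    left: Option<Box<SizedNode>>,
    right: Option<Box<SizedNode>>,
) -> Option<Box<SizedNode>> {
    let total = 1 + size(&left) + size(&right);
    Some(Box::new(SizedNode { val, size: total, left, right }))
}

/// k-th smallest value (1-indexed), following one root-to-leaf path.
fn kth_by_size(tree: &Option<Box<SizedNode>>, mut k: usize) -> Option<i32> {
    let mut current = tree.as_ref();
    while let Some(n) = current {
        let left_size = size(&n.left);
        if k <= left_size {
            // The answer is somewhere in the left subtree.
            current = n.left.as_ref();
        } else if k == left_size + 1 {
            // Exactly left_size values come before this node.
            return Some(n.val);
        } else {
            // Skip the left subtree and this node, then search the right.
            k -= left_size + 1;
            current = n.right.as_ref();
        }
    }
    None
}
```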
We’re going to solve this with an in-order traversal. We’re just going to traverse the elements of the BST in order from smallest value to largest, and count until we’ve encountered k elements. Then we’ll return the value at our “current” node.
An in-order traversal is conceptually simple. For a given node, we “visit” the left child, then visit the node itself, and then visit the right child. But the actual mechanics of doing this traversal can be a little tricky to think of on the spot if you haven’t practiced it before.
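The left-node-right recipe maps directly onto a recursive function. This sketch assumes a simplified, Box-based Node type (not LeetCode’s Rc<RefCell<TreeNode>>). It uses hypothetical names visit and kth_smallest_recursive, and threads a mutable counter through the recursion so it stops as soon as it visits the k-th node. Recursion depth grows with the tree’s height, which is one reason the stack-based version is worth knowing.

```rust
/// Simplified, hypothetical tree node with owned children.
struct Node {
    val: i32,
    left: Option<Box<Node>>,
    right: Option<Box<Node>>,
}

fn node(val: i32, left: Option<Box<Node>>, right: Option<Box<Node>>) -> Option<Box<Node>> {
    Some(Box::new(Node { val, left, right }))
}

/// In-order visit: left subtree, then the node, then the right subtree.
/// `remaining` counts down; the node that brings it to 0 is the answer.
fn visit(tree: &Option<Box<Node>>, remaining: &mut usize) -> Option<i32> {
    let n = tree.as_ref()?; // an empty subtree contributes nothing
    if let Some(found) = visit(&n.left, remaining) {
        return Some(found);
    }
    *remaining -= 1;
    if *remaining == 0 {
        return Some(n.val);
    }
    visit(&n.right, remaining)
}

/// Returns the k-th smallest value (1-indexed), or None if k is out of range.
fn kth_smallest_recursive(tree: &Option<Box<Node>>, k: usize) -> Option<i32> {
    if k == 0 {
        return None;
    }
    let mut remaining = k;
    visit(tree, &mut remaining)
}
```

On the example tree, k = 1, 4, and 8 give 5, 40, and 100, matching the problem statement.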
The main idea is that we’ll use a stack of nodes to track where we are in the tree. The stack holds our “current” node along with the ancestors we still need to visit, namely the ones whose left subtree we are currently inside. Our algorithm looks like this:
First, we create a stack from the root, following all left child nodes until we reach a node without a left child. This node has the smallest element of the tree.
Now, we begin processing, always considering the top node in the stack, while tracking the number of elements remaining until we hit k. If that number is down to 1, then the value at the top node in the stack is our result.
If not, we’ll decrement the number of elements remaining and pop this node from the stack. If its right child is Nil, the next node to process is simply whichever node is now on top of the stack. If the right child does exist, we push it onto the stack along with all of its left descendants, just as we did for the root.
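Here’s what the stack looks like for our example tree if we’re looking for k = 7 (which should be 50). The top of the stack is written first, and each comment’s k value counts which element, starting from 1, is now on top.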
```
[5, 32, 45]   -- k = 1, Initial left children of root
[32, 45]      -- k = 2, popped 5, no right child
[37, 40, 45]  -- k = 3, popped 32, right child was 40, which added left child 37
[40, 45]      -- k = 4, popped 37, no right child
[43, 45]      -- k = 5, popped 40, and 43 is right child
[45]          -- k = 6, popped 43
[50]          -- k = 7, popped 45 and added 50, the right child (no left children)
```
Since 50 is on top of the stack with k = 7, we can return 50.
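Here is a Rust sketch of the stack procedure just traced. It again assumes a simplified, Box-based Node rather than LeetCode’s type. Because the tree is only borrowed, the stack can hold plain &Node references. The hypothetical kth_smallest_traced function also records a snapshot of the stack (top first) each time it examines a new top node, so its output can be compared with the trace.

```rust
/// Simplified, hypothetical tree node with owned children.
struct Node {
    val: i32,
    left: Option<Box<Node>>,
    right: Option<Box<Node>>,
}

fn node(val: i32, left: Option<Box<Node>>, right: Option<Box<Node>>) -> Option<Box<Node>> {
    Some(Box::new(Node { val, left, right }))
}

/// Pushes `tree` and all of its left descendants onto `stack`.
fn push_left_chain<'a>(mut tree: &'a Option<Box<Node>>, stack: &mut Vec<&'a Node>) {
    while let Some(n) = tree {
        stack.push(n); // &Box<Node> coerces to &Node
        tree = &n.left;
    }
}

/// Returns the k-th smallest value (1-indexed) and the stack snapshots.
fn kth_smallest_traced(tree: &Option<Box<Node>>, k: usize) -> (Option<i32>, Vec<Vec<i32>>) {
    let mut snapshots: Vec<Vec<i32>> = Vec::new();
    if k == 0 {
        return (None, snapshots); // k is 1-indexed
    }
    let mut stack: Vec<&Node> = Vec::new();
    push_left_chain(tree, &mut stack);
    let mut remaining = k;
    while let Some(&top) = stack.last() {
        // Record the stack top-first, like the trace.
        snapshots.push(stack.iter().rev().map(|n| n.val).collect::<Vec<i32>>());
        if remaining == 1 {
            return (Some(top.val), snapshots);
        }
        remaining -= 1;
        stack.pop();
        // Continue with the right subtree, starting at its leftmost node.
        push_left_chain(&top.right, &mut stack);
    }
    (None, snapshots)
}
```

With k = 7 on the example tree, the seven recorded snapshots are exactly the seven rows of the trace, and the answer is 50.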
Haskell Solution
Let’s code this up! We’ll start with Haskell, since Rust is, once again, somewhat tricky due to TreeNode handling. To start, let’s remind ourselves of the recursive TreeNode type:
```haskell
data TreeNode = Nil | Node Int TreeNode TreeNode
  deriving (Show, Eq)
```
Now when writing up the algorithm, we first want to define a helper function addLeftNodesToStack. We’ll use this at the beginning of the algorithm, and then again each time we encounter a right child.
This helper will take a TreeNode and the existing stack, and return the modified stack.
```haskell
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = ...
  where
    addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
    addLeftNodesToStack = ...
```
As recursive helpers go, this is a simple one! If our input node is Nil, we return the original stack. We want to maintain the invariant that we never include Nil values in our stack! But if we have an actual Node, we just add it to the stack and recurse on its left child.
```haskell
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = ...
  where
    addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
    addLeftNodesToStack Nil acc = acc
    addLeftNodesToStack root@(Node _ left _) acc = addLeftNodesToStack left (root : acc)
```
Now it’s time to implement our algorithm for finding the k-th element. This will be a recursive function that takes the number of elements remaining, as well as the current stack. We’ll call this initially with k and the stack we get from adding the left nodes of the root:
```haskell
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
  where
    addLeftNodesToStack = ...

    findK :: Int -> [TreeNode] -> Int
    findK = ...
```
This function has a couple of error cases. We expect a non-empty stack (our input k is constrained to be within the size of the tree), and we expect the top to be non-Nil. After that, we have our base case where k = 1, and we return the value at this node.
Finally, we get our recursive case. We decrement the remaining count, drop this node from the stack, and add its right child (along with that child’s left descendants) to the rest of the stack.
```haskell
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
  where
    addLeftNodesToStack = ...

    findK :: Int -> [TreeNode] -> Int
    findK k [] = error $ "Found empty list expecting k: " ++ show k
    findK _ (Nil : _) = error "Added Nil to stack!"
    findK 1 (Node x _ _ : _) = x
    findK k (Node _ _ right : rest) = findK (k - 1) (addLeftNodesToStack right rest)
```
This completes our solution!
```haskell
kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
  where
    addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
    addLeftNodesToStack Nil acc = acc
    addLeftNodesToStack root@(Node _ left _) acc = addLeftNodesToStack left (root : acc)

    findK :: Int -> [TreeNode] -> Int
    findK k [] = error $ "Found empty list expecting k: " ++ show k
    findK _ (Nil : _) = error "Added Nil to stack!"
    findK 1 (Node x _ _ : _) = x
    findK k (Node _ _ right : rest) = findK (k - 1) (addLeftNodesToStack right rest)
```
Rust Solution
In our Rust solution, we’re once again working with this TreeNode type, including the 3 wrapper layers (Option, Rc, and RefCell):
```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}
```
Our first step will be to implement the helper function to add the “left” nodes. This function will take a “root” node as well as a mutable reference to the stack so we can add nodes to it.
```rust
fn add_left_nodes_to_stack(
    node: Option<Rc<RefCell<TreeNode>>>,
    stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    ...
}
```
You’ll notice that stack does not actually use the Option wrapper, only Rc and RefCell. Remember that in our Haskell solution, we wanted to maintain the invariant that we never add Nil nodes to the stack. This Rust solution enforces that constraint at compile time, because the stack’s element type has no way to represent a missing node.
To implement this function, we’ll use the same trick we did when inverting trees, pattern matching on node to detect whether it is Some or None. If it is None, we don’t have to do anything.
```rust
fn add_left_nodes_to_stack(
    node: Option<Rc<RefCell<TreeNode>>>,
    stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    if let Some(current) = node {
        ...
    }
}
```
Since current is now unwrapped from Option, we can push it onto the stack. As in our previous problem, though, we clone it first! The stack needs to own its own reference to the node, but we still need current afterward to reach its left child. Cloning the Rc only copies the pointer and increments its reference count; the TreeNode itself is not copied.
```rust
fn add_left_nodes_to_stack(
    node: Option<Rc<RefCell<TreeNode>>>,
    stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    if let Some(current) = node {
        stack.push(current.clone());
        ...
    }
}
```
Now we’ll recurse on the left child of current. In order to reach the TreeNode inside the Rc/RefCell, we have to use borrow. Then we can grab the left value. But again, we have to clone it before we make the recursive call, since the helper takes its Option argument by value and we can’t move the child out of a borrowed node. Here’s the final implementation of this helper:
```rust
fn add_left_nodes_to_stack(
    node: Option<Rc<RefCell<TreeNode>>>,
    stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    if let Some(current) = node {
        stack.push(current.clone());
        add_left_nodes_to_stack(current.borrow().left.clone(), stack);
    }
}
```
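To see why those clones are cheap, here is a small sketch built on the same TreeNode definition. The hypothetical push_and_get_left function performs the two clones from add_left_nodes_to_stack. Rc::clone copies a pointer and bumps the strong count, so the stack and the tree share one node instead of holding separate copies.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

/// Hypothetical helper mirroring the two clones in add_left_nodes_to_stack:
/// push a handle to `node`, then return a handle to its left child.
fn push_and_get_left(
    node: &Rc<RefCell<TreeNode>>,
    stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) -> Option<Rc<RefCell<TreeNode>>> {
    // Copies the pointer and bumps the strong count; the node is shared.
    stack.push(Rc::clone(node));
    // borrow() grants temporary read access; cloning the Option clones the
    // inner Rc (if any), again without copying the child node.
    node.borrow().left.clone()
}
```

After one call on a parent with a left child, the parent is reachable through two handles (its own and the stack’s). The child is reachable through three: its own, the parent’s left field, and the returned Option. All of them point at the same allocations.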
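We could have implemented the helper with a while loop instead of recursion. This would actually use less memory in Rust: Rust doesn’t guarantee tail-call elimination, so each recursive call keeps its own stack frame. We would have to make some changes, though, like keeping a mutable local variable that holds the current node instead of passing it as a new parameter each time.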
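A loop-based version of the helper might look like the following sketch. The names add_left_nodes_to_stack_iter and tree_node are hypothetical and do not come from the solution in this post. A mutable local, current, replaces the recursive parameter. Reading the left child before pushing lets us move the node itself into the stack, so only the child handle needs cloning.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

/// Hypothetical convenience constructor for building test trees.
fn tree_node(
    val: i32,
    left: Option<Rc<RefCell<TreeNode>>>,
    right: Option<Rc<RefCell<TreeNode>>>,
) -> Option<Rc<RefCell<TreeNode>>> {
    Some(Rc::new(RefCell::new(TreeNode { val, left, right })))
}

/// Loop-based alternative to the recursive add_left_nodes_to_stack.
fn add_left_nodes_to_stack_iter(
    node: Option<Rc<RefCell<TreeNode>>>,
    stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    // A mutable local replaces the parameter of each recursive call.
    let mut current = node;
    while let Some(n) = current {
        // Grab the left child first, then move `n` itself onto the stack,
        // so only the child's Rc needs to be cloned.
        current = n.borrow().left.clone();
        stack.push(n);
    }
}
```

Starting from a root of 45 whose left spine is 32 and then 5, this leaves 45, 32, and 5 on the stack (bottom to top). That is the same stack the recursive helper builds.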
Now we can move on to the core function. We’ll start by defining our stack and the number of “remaining” values (initially k). We’ll also call our helper to get the initial stack.
```rust
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
    let mut stack = Vec::new();
    let mut remaining = k;
    add_left_nodes_to_stack(root, &mut stack);
    ...
}
```
Now we want to pop the top element from the stack and pattern match on Some. If there are no more values, we’ll actually panic, because the problem constraints mean our stack should never run out before we reach the k-th element. Unlike our helper, we’ll use a while loop here instead of more recursion:
```rust
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
    let mut stack = Vec::new();
    let mut remaining = k;
    add_left_nodes_to_stack(root, &mut stack);
    while let Some(current) = stack.pop() {
        ...
    }
    panic!("k is larger than number of nodes");
}
```
Now the inside of the loop is simple, following what we did in Haskell. If our remaining count is 1, then we have found the correct node, so we borrow the node from the RefCell and return its value. Otherwise, we decrement the count and use our helper on the “right” child of the node we just popped. As usual, the RefCell wrapper means we need to borrow to get the right value from the TreeNode, and then we clone this child as we pass it to the helper.
```rust
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
    let mut stack = Vec::new();
    let mut remaining = k;
    add_left_nodes_to_stack(root, &mut stack);
    while let Some(current) = stack.pop() {
        if remaining == 1 {
            return current.borrow().val;
        }
        remaining -= 1;
        add_left_nodes_to_stack(current.borrow().right.clone(), &mut stack);
    }
    panic!("k is larger than number of nodes");
}
```
And that’s it! Here’s the complete Rust solution:
```rust
use std::cell::RefCell;
use std::rc::Rc;

// Pushes `node` and its chain of left descendants onto `stack`.
fn add_left_nodes_to_stack(
    node: Option<Rc<RefCell<TreeNode>>>,
    stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    if let Some(current) = node {
        stack.push(current.clone());
        add_left_nodes_to_stack(current.borrow().left.clone(), stack);
    }
}

// Returns the k-th smallest value (1-indexed) using an in-order traversal.
pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
    let mut stack = Vec::new();
    let mut remaining = k;
    add_left_nodes_to_stack(root, &mut stack);
    while let Some(current) = stack.pop() {
        if remaining == 1 {
            return current.borrow().val;
        }
        remaining -= 1;
        // Visit the right subtree next, starting from its leftmost node.
        add_left_nodes_to_stack(current.borrow().right.clone(), &mut stack);
    }
    panic!("k is larger than number of nodes");
}
```
Conclusion
In-order traversal is a great pattern to commit to memory, as many different tree problems will require you to apply it. Hopefully the details of Rust’s RefCells are getting more familiar. Next week we’ll do one more problem with binary trees.
If you want to do some deep work with Haskell and binary trees, take a look at our Solve.hs course, where you’ll learn about many different data structures in Haskell, and get the chance to write a balanced binary search tree from scratch!