In-Order Traversal in Haskell and Rust

Last time around, we started exploring binary trees. We began with a simple problem (inverting a tree), but ran into some of the difficulties of implementing a recursive data structure in Rust.

Today we’ll do a slightly harder problem (LeetCode rates it as “Medium” instead of “Easy”). This problem also works specifically with a binary search tree rather than a plain binary tree. In a search tree, the “values” on the nodes are orderable: all the values to the “left” of any given node (in its left subtree) are no greater than that node’s value, and all the values to the “right” (in its right subtree) are no smaller.
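To make that ordering property concrete, here’s a minimal Rust sketch that checks it. It assumes a hypothetical Box-based Tree enum (a rough Rust analogue of the Haskell TreeNode type we’ll use later) plus hypothetical leaf, branch, and example_tree helpers. Note that each value has to respect bounds inherited from all of its ancestors, not just from its parent.

```rust
// Hypothetical Box-based tree, roughly mirroring the Haskell type
// `data TreeNode = Nil | Node Int TreeNode TreeNode`.
#[derive(Debug)]
enum Tree {
    Nil,
    Node(i32, Box<Tree>, Box<Tree>),
}

// Hypothetical construction helpers.
fn branch(v: i32, left: Tree, right: Tree) -> Tree {
    Tree::Node(v, Box::new(left), Box::new(right))
}

fn leaf(v: i32) -> Tree {
    branch(v, Tree::Nil, Tree::Nil)
}

// The example tree from this article.
fn example_tree() -> Tree {
    branch(
        45,
        branch(32, leaf(5), branch(40, leaf(37), leaf(43))),
        branch(50, Tree::Nil, leaf(100)),
    )
}

// Every value in the left subtree must be <= the node's value, and every
// value in the right subtree must be >= it. `lo` and `hi` carry those
// constraints down from all ancestors, not just the parent.
fn is_bst(tree: &Tree, lo: Option<i32>, hi: Option<i32>) -> bool {
    match tree {
        Tree::Nil => true,
        Tree::Node(v, left, right) => {
            let within_lo = lo.map_or(true, |l| *v >= l);
            let within_hi = hi.map_or(true, |h| *v <= h);
            within_lo
                && within_hi
                && is_bst(left, lo, Some(*v))
                && is_bst(right, Some(*v), hi)
        }
    }
}
```

For instance, making 50 the right child of 32 under a root of 45 passes a parent-only comparison (50 >= 32), but is_bst rejects it because 50 would sit in 45’s left subtree.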

Binary search trees are at the heart of many ordered Set types. In our problem-solving course Solve.hs, you’ll get the chance to build a self-balancing binary search tree from scratch, which involves some really cool algorithmic tricks!

The Problem

Though it’s harder than our previous problem, today’s problem is still straightforward. We are taking an ordered binary search tree and finding the k-th smallest element in that tree, where k is the second input to our function.

So suppose our tree looks like this:

     45
    /  \
   32  50
  /  \   \
 5   40   100
    /  \
  37   43

Our input k is 1-indexed. So if we get 1 as our input, we should return 5, the smallest element in the tree. If we receive 4, we should return 40, the 4th smallest element after 5, 32 and 37. If we get 8, we’ll return 100, the largest element in the tree.

The Algorithm

Balanced binary search trees are designed to give logarithmic-time access, insertion, and deletion. If our tree were annotated so that each node stored the size of its subtree (the number of nodes beneath it, plus itself), we’d be able to solve this problem in logarithmic time as well.
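For comparison, here’s a rough Rust sketch of that annotated design. SizedTree and its size field are hypothetical; the problem’s tree carries no such annotation. Knowing the size of each left subtree tells us whether the k-th element lies to the left, at the current node, or to the right, so we only walk a single path down the tree.

```rust
// Hypothetical size-annotated BST: each node caches how many nodes are in
// its subtree (itself included). The problem's TreeNode has no such field.
enum SizedTree {
    Nil,
    Node {
        val: i32,
        size: usize,
        left: Box<SizedTree>,
        right: Box<SizedTree>,
    },
}

impl SizedTree {
    fn size(&self) -> usize {
        match self {
            SizedTree::Nil => 0,
            SizedTree::Node { size, .. } => *size,
        }
    }

    // Builds a node and fills in `size` from its children.
    fn node(val: i32, left: SizedTree, right: SizedTree) -> SizedTree {
        let size = 1 + left.size() + right.size();
        SizedTree::Node { val, size, left: Box::new(left), right: Box::new(right) }
    }

    fn leaf(val: i32) -> SizedTree {
        SizedTree::node(val, SizedTree::Nil, SizedTree::Nil)
    }

    // k-th smallest (1-indexed) by walking one root-to-leaf path:
    // O(height), which is O(log n) only when the tree is balanced.
    fn select(&self, mut k: usize) -> Option<i32> {
        let mut current = self;
        loop {
            match current {
                SizedTree::Nil => return None,
                SizedTree::Node { val, left, right, .. } => {
                    let left_size = left.size();
                    if k <= left_size {
                        current = left.as_ref(); // answer is in the left subtree
                    } else if k == left_size + 1 {
                        return Some(*val); // this node is the k-th
                    } else {
                        k -= left_size + 1; // skip the left subtree and this node
                        current = right.as_ref();
                    }
                }
            }
        }
    }
}
```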

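However, we want to assume a minimal tree design, where each node only holds its own value and pointers to its two children. With these constraints, we have to step through the smallest elements one by one, so our algorithm takes time linear in the input k (plus the height of the tree, for the initial walk down to the smallest element).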

We’re going to solve this with an in-order traversal. We’re just going to traverse the elements of the BST in order from smallest value to largest, and count until we’ve encountered k elements. Then we’ll return the value at our “current” node.

An in-order traversal is conceptually simple. For a given node, we “visit” the left child, then visit the node itself, and then visit the right child. But the actual mechanics of doing this traversal can be a little tricky to think of on the spot if you haven’t practiced it before.
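The recursive form of that definition is short. Here’s a sketch on a hypothetical Box-based Tree enum (with hypothetical branch, leaf, and example_tree helpers). kth_smallest_naive is a hypothetical baseline that collects every value and then indexes, so unlike the stack-based approach this article builds, it always visits all n nodes.

```rust
// Hypothetical Box-based tree, roughly mirroring the Haskell TreeNode type.
#[derive(Debug)]
enum Tree {
    Nil,
    Node(i32, Box<Tree>, Box<Tree>),
}

fn branch(v: i32, left: Tree, right: Tree) -> Tree {
    Tree::Node(v, Box::new(left), Box::new(right))
}

fn leaf(v: i32) -> Tree {
    branch(v, Tree::Nil, Tree::Nil)
}

// The example tree from this article.
fn example_tree() -> Tree {
    branch(
        45,
        branch(32, leaf(5), branch(40, leaf(37), leaf(43))),
        branch(50, Tree::Nil, leaf(100)),
    )
}

// In-order traversal: left subtree, then the node, then the right subtree.
fn in_order(tree: &Tree, out: &mut Vec<i32>) {
    if let Tree::Node(v, left, right) = tree {
        in_order(left, out);
        out.push(*v);
        in_order(right, out);
    }
}

// Baseline: collect all values in order, then index (k is 1-indexed).
// k = 0 yields None instead of underflowing.
fn kth_smallest_naive(tree: &Tree, k: usize) -> Option<i32> {
    let mut values = Vec::new();
    in_order(tree, &mut values);
    values.get(k.checked_sub(1)?).copied()
}
```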

The main idea is that we’ll use a stack of nodes to track where we are in the tree. The top of the stack is our “current” node, and below it sit the ancestors we still need to visit, that is, the ones from which we stepped down to a left child. Our algorithm looks like this:

First, we create a stack from the root, following all left child nodes until we reach a node without a left child. This node has the smallest element of the tree.

Now, we begin processing, always considering the top node in the stack, while tracking the number of elements remaining until we hit k. If that number is down to 1, then the value at the top node in the stack is our result.

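If not, we pop this node off the stack and decrement the number of elements remaining. Then we check the node’s right child. If the right child is Nil, there’s nothing to add, and we move on to whatever node is now on top of the stack (the next-smallest ancestor we haven’t visited yet). If the right child does exist, we push it onto the stack along with its whole chain of left children, just as we did for the root.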
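Here’s a minimal Rust sketch of those steps, on a hypothetical Box-based Tree enum. Each stack entry stores a node’s value together with a reference to its right subtree, so, as with the Haskell invariant, a Nil can never end up on the stack. push_left_spine and kth_smallest_box are illustrative names, not from the solutions below.

```rust
// Hypothetical Box-based tree, roughly mirroring the Haskell TreeNode type.
#[derive(Debug)]
enum Tree {
    Nil,
    Node(i32, Box<Tree>, Box<Tree>),
}

fn branch(v: i32, left: Tree, right: Tree) -> Tree {
    Tree::Node(v, Box::new(left), Box::new(right))
}

fn leaf(v: i32) -> Tree {
    branch(v, Tree::Nil, Tree::Nil)
}

fn example_tree() -> Tree {
    branch(
        45,
        branch(32, leaf(5), branch(40, leaf(37), leaf(43))),
        branch(50, Tree::Nil, leaf(100)),
    )
}

// Push `tree` and its chain of left children. Entries are (value, right subtree).
fn push_left_spine<'a>(mut tree: &'a Tree, stack: &mut Vec<(i32, &'a Tree)>) {
    while let Tree::Node(v, left, right) = tree {
        stack.push((*v, right.as_ref()));
        tree = left.as_ref();
    }
}

// Pop nodes in ascending order, counting down until the k-th (1-indexed).
fn kth_smallest_box(root: &Tree, k: usize) -> Option<i32> {
    if k == 0 {
        return None; // k is 1-indexed
    }
    let mut stack = Vec::new();
    push_left_spine(root, &mut stack);
    let mut remaining = k;
    while let Some((value, right)) = stack.pop() {
        if remaining == 1 {
            return Some(value);
        }
        remaining -= 1;
        push_left_spine(right, &mut stack); // no-op if `right` is Nil
    }
    None // k is larger than the number of nodes
}
```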

Here’s what the stack looks like with our example tree if we’re looking for k = 7 (which should be 50). The top of the stack is written first, and k is the rank of the node currently on top.

[5, 32, 45] -- k = 1, Initial left children of root
[32, 45] -- k = 2, popped 5, no right child
[37, 40, 45] -- k = 3, popped 32; added its right child 40, then 40’s left child 37
[40, 45] -- k = 4, popped 37, no right child
[43, 45] -- k = 5, popped 40 and added its right child 43
[45] -- k = 6, popped 43, no right child
[50] -- k = 7, popped 45 and added its right child 50 (which has no left children)

Since 50 is on top of the stack with k = 7, we can return 50.
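To double-check the table, here’s a sketch of a hypothetical stack_trace function that records the stack (top first) each time we are about to look at its top node. It uses a hypothetical Box-based Tree and a push_left_spine helper that pushes a node and its left children, where each stack entry is a node’s value plus a reference to its right subtree.

```rust
// Hypothetical Box-based tree, roughly mirroring the Haskell TreeNode type.
#[derive(Debug)]
enum Tree {
    Nil,
    Node(i32, Box<Tree>, Box<Tree>),
}

fn branch(v: i32, left: Tree, right: Tree) -> Tree {
    Tree::Node(v, Box::new(left), Box::new(right))
}

fn leaf(v: i32) -> Tree {
    branch(v, Tree::Nil, Tree::Nil)
}

fn example_tree() -> Tree {
    branch(
        45,
        branch(32, leaf(5), branch(40, leaf(37), leaf(43))),
        branch(50, Tree::Nil, leaf(100)),
    )
}

// Push `tree` and its chain of left children. Entries are (value, right subtree).
fn push_left_spine<'a>(mut tree: &'a Tree, stack: &mut Vec<(i32, &'a Tree)>) {
    while let Tree::Node(v, left, right) = tree {
        stack.push((*v, right.as_ref()));
        tree = left.as_ref();
    }
}

// Snapshot the stack values (top of stack first) before each of the first k visits.
fn stack_trace(root: &Tree, k: usize) -> Vec<Vec<i32>> {
    let mut stack: Vec<(i32, &Tree)> = Vec::new();
    push_left_spine(root, &mut stack);
    let mut snapshots: Vec<Vec<i32>> = Vec::new();
    for _ in 0..k {
        snapshots.push(stack.iter().rev().map(|&(v, _)| v).collect());
        // Visit the top node: pop it, then push its right child's left chain.
        match stack.pop() {
            Some((_, right)) => push_left_spine(right, &mut stack),
            None => break, // k exceeded the number of nodes
        }
    }
    snapshots
}
```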

Haskell Solution

Let’s code this up! We’ll start with Haskell, since Rust is, once again, somewhat tricky due to TreeNode handling. To start, let’s remind ourselves of the recursive TreeNode type:

data TreeNode = Nil | Node Int TreeNode TreeNode
  deriving (Show, Eq)

Now when writing up the algorithm, we first want to define a helper function addLeftNodesToStack. We’ll use this at the beginning of the algorithm, and then again each time we encounter a right child.

This helper will take a TreeNode and the existing stack, and return the modified stack.

kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = ...
  where
    addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
    addLeftNodesToStack = ...

As far as recursive helpers go, this is a simple one! If our input node is Nil, we return the original stack; we want to maintain the invariant that the stack never contains Nil values. But if we have a Node, we just push it onto the stack and recurse on its left child.

kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = ...
  where
    addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
    addLeftNodesToStack Nil acc = acc
    addLeftNodesToStack root@(Node _ left _) acc = addLeftNodesToStack left (root : acc)

Now it’s time to implement our algorithm for finding the k-th element. This will be a recursive function that takes the number of elements remaining, as well as the current stack. We’ll call this initially with k and the stack we get from adding the left nodes of the root:

kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
  where
    addLeftNodesToStack = ...

    findK :: Int -> [TreeNode] -> Int
    findK = ...

This function has a couple of error cases. We expect a non-empty stack (the problem guarantees that k is no larger than the number of nodes in the tree), and we expect the top of the stack to be non-Nil. After that, we have our base case where k = 1, and we return the value at this node.

Finally, we get our recursive case. We decrement the remaining count, pop this node, and push its right child, along with that child’s chain of left children, onto the rest of the stack.

kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
  where
    addLeftNodesToStack = ...

    findK :: Int -> [TreeNode] -> Int
    findK k [] = error $ "Found empty list expecting k: " ++ show k
    findK _ (Nil : _) = error "Added Nil to stack!"
    findK 1 (Node x _ _ : _) = x
    findK k (Node _ _ right : rest) = findK (k - 1) (addLeftNodesToStack right rest)

This completes our solution!

kthSmallest :: TreeNode -> Int -> Int
kthSmallest root' k' = findK k' (addLeftNodesToStack root' [])
  where
    addLeftNodesToStack :: TreeNode -> [TreeNode] -> [TreeNode]
    addLeftNodesToStack Nil acc = acc
    addLeftNodesToStack root@(Node _ left _) acc = addLeftNodesToStack left (root : acc)

    findK :: Int -> [TreeNode] -> Int
    findK k [] = error $ "Found empty list expecting k: " ++ show k
    findK _ (Nil : _) = error "Added Nil to stack!"
    findK 1 (Node x _ _ : _) = x
    findK k (Node _ _ right : rest) = findK (k - 1) (addLeftNodesToStack right rest)
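Taken together, findK and addLeftNodesToStack behave like a hand-rolled in-order iterator. As an alternative packaging (not the approach used in this article), the same stack can back a Rust Iterator, after which the k-th smallest element is just nth(k - 1). This sketch assumes a hypothetical Box-based Tree enum mirroring the Haskell type; InOrder and kth_smallest_iter are illustrative names.

```rust
// Hypothetical Box-based tree, mirroring `data TreeNode = Nil | Node Int TreeNode TreeNode`.
#[derive(Debug)]
enum Tree {
    Nil,
    Node(i32, Box<Tree>, Box<Tree>),
}

fn branch(v: i32, left: Tree, right: Tree) -> Tree {
    Tree::Node(v, Box::new(left), Box::new(right))
}

fn leaf(v: i32) -> Tree {
    branch(v, Tree::Nil, Tree::Nil)
}

fn example_tree() -> Tree {
    branch(
        45,
        branch(32, leaf(5), branch(40, leaf(37), leaf(43))),
        branch(50, Tree::Nil, leaf(100)),
    )
}

// In-order iterator; like the Haskell stack, `stack` only ever holds Nodes.
struct InOrder<'a> {
    stack: Vec<&'a Tree>,
}

impl<'a> InOrder<'a> {
    fn new(root: &'a Tree) -> Self {
        let mut iter = InOrder { stack: Vec::new() };
        iter.push_left_spine(root);
        iter
    }

    // Like addLeftNodesToStack: push a node and its chain of left children.
    fn push_left_spine(&mut self, mut tree: &'a Tree) {
        while let Tree::Node(_, left, _) = tree {
            self.stack.push(tree);
            tree = left.as_ref();
        }
    }
}

impl<'a> Iterator for InOrder<'a> {
    type Item = i32;

    // Like one step of findK: pop the top node, push its right child's left chain.
    fn next(&mut self) -> Option<i32> {
        match self.stack.pop()? {
            Tree::Node(v, _, right) => {
                self.push_left_spine(right.as_ref());
                Some(*v)
            }
            Tree::Nil => unreachable!("Nil is never pushed onto the stack"),
        }
    }
}

// 1-indexed k-th smallest; None if k is 0 or larger than the tree.
fn kth_smallest_iter(root: &Tree, k: usize) -> Option<i32> {
    InOrder::new(root).nth(k.checked_sub(1)?)
}
```

Because Rust iterators are lazy, nth(k - 1) stops after k elements, so this keeps the same cost as the explicit loop: the initial left descent plus k steps.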

Rust Solution

In our Rust solution, we’re once again working with this TreeNode type, including its three wrapper layers (Option, Rc, and RefCell):

use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
  pub val: i32,
  pub left: Option<Rc<RefCell<TreeNode>>>,
  pub right: Option<Rc<RefCell<TreeNode>>>,
}
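None of the Rust snippets here actually builds a tree, so if you want to experiment locally, here’s a sketch of some construction helpers. TreeNode::new is similar to the constructor in LeetCode’s Rust template; the Link alias and the branch, leaf, and example_tree functions are hypothetical names for illustration.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

impl TreeNode {
    pub fn new(val: i32) -> Self {
        TreeNode { val, left: None, right: None }
    }
}

// Hypothetical alias for the three wrapper layers.
type Link = Option<Rc<RefCell<TreeNode>>>;

// Hypothetical helper: wrap a value and two children in Option<Rc<RefCell<...>>>.
fn branch(val: i32, left: Link, right: Link) -> Link {
    Some(Rc::new(RefCell::new(TreeNode { val, left, right })))
}

fn leaf(val: i32) -> Link {
    Some(Rc::new(RefCell::new(TreeNode::new(val))))
}

// The example tree from this article.
fn example_tree() -> Link {
    branch(
        45,
        branch(32, leaf(5), branch(40, leaf(37), leaf(43))),
        branch(50, None, leaf(100)),
    )
}
```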

Our first step will be to implement the helper function to add the “left” nodes. This function will take a “root” node as well as a mutable reference to the stack so we can add nodes to it.

fn add_left_nodes_to_stack(
        node: Option<Rc<RefCell<TreeNode>>>,
        stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    ...
}

You’ll notice that the stack does not use the Option wrapper, only Rc and RefCell. Remember that in our Haskell solution, we wanted to make sure we never add null (Nil) nodes to the stack, and the best we could do there was raise an error at runtime. This Rust solution enforces that constraint at compile time: a None simply can’t be stored in a Vec<Rc<RefCell<TreeNode>>>.

To implement this function, we’ll use the same trick we did when inverting trees to pattern match on node and detect if it is Some or None. If it is None, we don’t have to do anything.

fn add_left_nodes_to_stack(
        node: Option<Rc<RefCell<TreeNode>>>,
        stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    if let Some(current) = node {
        ...
    }
}

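Since current is now unwrapped from Option, we can push it to the stack. As in our previous problem, though, we have to clone it first! The stack needs to own its own reference (as wrapped by Rc), and we still need current afterwards to reach its left child. Cloning an Rc is cheap: it copies the pointer and increments a reference count, without copying the TreeNode itself.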

fn add_left_nodes_to_stack(
        node: Option<Rc<RefCell<TreeNode>>>,
        stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    if let Some(current) = node {
        stack.push(current.clone());
        ...
    }
}
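Here’s a small sketch of what that clone actually does, using the standard Rc::strong_count and Rc::ptr_eq functions. Without the clone, stack.push(current) would move current into the vector, and any later use of current (such as reading its left child) would be a use-after-move compile error. The function clone_shares_node is a hypothetical name for this demo.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

// Hypothetical demo of `stack.push(current.clone())`.
// Returns (reference count, same allocation?, value seen through the stack).
fn clone_shares_node() -> (usize, bool, i32) {
    let current = Rc::new(RefCell::new(TreeNode { val: 32, left: None, right: None }));
    let mut stack: Vec<Rc<RefCell<TreeNode>>> = Vec::new();

    // Copies the pointer and increments the count; the node is not copied.
    stack.push(current.clone());

    let count = Rc::strong_count(&current); // `current` plus the stack's copy
    let same_node = Rc::ptr_eq(&current, &stack[0]);

    // A write through one handle is visible through the other.
    current.borrow_mut().val = 33;
    let seen_via_stack = stack[0].borrow().val;

    (count, same_node, seen_via_stack)
}
```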

Now we’ll recurse on the left child of current. To get at the TreeNode inside Rc/RefCell, we have to use borrow. Then we can grab the left value. But again, we have to clone it before we make the recursive call, since we can’t move a field out of a borrowed node. Here’s the final implementation of this helper:

fn add_left_nodes_to_stack(
        node: Option<Rc<RefCell<TreeNode>>>,
        stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    if let Some(current) = node {
        stack.push(current.clone());
        add_left_nodes_to_stack(current.borrow().left.clone(), stack);
    }
}

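We could have implemented the helper with a while loop instead of recursion. That would actually use less memory in Rust, since each level of recursion adds a frame to the call stack. We would have to make some changes, though, like keeping the current node in a mutable local variable and rebinding it to its left child on each iteration.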
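Here’s a rough sketch of that loop-based version; add_left_nodes_to_stack_iter is a hypothetical name, and the branch, leaf, and example_tree builders are only there to make it testable. The local node is rebound to the left child on each iteration, and because we read the child before pushing, current can be moved into the stack without an extra clone.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

// Hypothetical builder helpers, only used to create a test tree.
type Link = Option<Rc<RefCell<TreeNode>>>;

fn branch(val: i32, left: Link, right: Link) -> Link {
    Some(Rc::new(RefCell::new(TreeNode { val, left, right })))
}

fn leaf(val: i32) -> Link {
    branch(val, None, None)
}

fn example_tree() -> Link {
    branch(
        45,
        branch(32, leaf(5), branch(40, leaf(37), leaf(43))),
        branch(50, None, leaf(100)),
    )
}

// Loop-based variant of add_left_nodes_to_stack.
fn add_left_nodes_to_stack_iter(
    mut node: Option<Rc<RefCell<TreeNode>>>,
    stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    while let Some(current) = node {
        // Clone the left child first (the temporary borrow ends here)...
        node = current.borrow().left.clone();
        // ...so `current` itself can be moved into the stack.
        stack.push(current);
    }
}
```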

Now we can move on to the core function. We’ll start by defining our key variables: the stack and the number of “remaining” values (initially k). We’ll also call our helper to build the initial stack.

pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
    let mut stack = Vec::new();
    let mut remaining = k;

    add_left_nodes_to_stack(root, &mut stack);
    ...
}

Now we want to pop the top element from the stack, and pattern match to require Some. If there are no more values, we’ll actually panic, because the problem guarantees that k is between 1 and the number of nodes, so we should always return before the stack runs out. Unlike our helper, we’ll use a while loop here instead of more recursion:

pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
    let mut stack = Vec::new();
    let mut remaining = k;

    add_left_nodes_to_stack(root, &mut stack);

    while let Some(current) = stack.pop() {
        ...
    }
    panic!("k is larger than number of nodes");
}

Now the inside of the loop is simple, following what we did in Haskell. If remaining is 1, then we have found the correct node: we borrow the node from the RefCell and return its value. Otherwise, we decrement the count and use our helper on the “right” child of the node we just popped. As usual, the RefCell wrapper means we need to borrow to get the right value from the TreeNode, and then we clone this child as we pass it to the helper.

pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
    let mut stack = Vec::new();
    let mut remaining = k;

    add_left_nodes_to_stack(root, &mut stack);

    while let Some(current) = stack.pop() {
        if remaining == 1 {
            return current.borrow().val;
        }
        remaining -= 1;
        add_left_nodes_to_stack(current.borrow().right.clone(), &mut stack);
    }
    panic!("k is larger than number of nodes");
}

And that’s it! Here’s the complete Rust solution:

use std::cell::RefCell;
use std::rc::Rc;

// TreeNode is the struct defined above.

fn add_left_nodes_to_stack(
        node: Option<Rc<RefCell<TreeNode>>>,
        stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    if let Some(current) = node {
        stack.push(current.clone());
        add_left_nodes_to_stack(current.borrow().left.clone(), stack);
    }
}

pub fn kth_smallest(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> i32 {
    let mut stack = Vec::new();
    let mut remaining = k;

    add_left_nodes_to_stack(root, &mut stack);

    while let Some(current) = stack.pop() {
        if remaining == 1 {
            return current.borrow().val;
        }
        remaining -= 1;
        add_left_nodes_to_stack(current.borrow().right.clone(), &mut stack);
    }
    panic!("k is larger than number of nodes");
}
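As written, kth_smallest panics when k is out of range, which is fine under LeetCode’s constraints. If you’d rather let the caller decide, here’s a sketch of a hypothetical kth_smallest_checked that returns Option<i32> instead. It repeats the TreeNode struct and the helper, and adds hypothetical tree-building functions, so that it compiles on its own.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

// Hypothetical builder helpers, only used to create a test tree.
type Link = Option<Rc<RefCell<TreeNode>>>;

fn branch(val: i32, left: Link, right: Link) -> Link {
    Some(Rc::new(RefCell::new(TreeNode { val, left, right })))
}

fn leaf(val: i32) -> Link {
    branch(val, None, None)
}

fn example_tree() -> Link {
    branch(
        45,
        branch(32, leaf(5), branch(40, leaf(37), leaf(43))),
        branch(50, None, leaf(100)),
    )
}

// The recursive helper from above.
fn add_left_nodes_to_stack(
    node: Option<Rc<RefCell<TreeNode>>>,
    stack: &mut Vec<Rc<RefCell<TreeNode>>>,
) {
    if let Some(current) = node {
        stack.push(current.clone());
        add_left_nodes_to_stack(current.borrow().left.clone(), stack);
    }
}

// Same loop as kth_smallest, but "no such element" becomes None, not a panic.
pub fn kth_smallest_checked(root: Option<Rc<RefCell<TreeNode>>>, k: i32) -> Option<i32> {
    if k < 1 {
        return None; // k is 1-indexed, so there is no 0th (or negative) element
    }
    let mut stack = Vec::new();
    let mut remaining = k;

    add_left_nodes_to_stack(root, &mut stack);

    while let Some(current) = stack.pop() {
        if remaining == 1 {
            return Some(current.borrow().val);
        }
        remaining -= 1;
        add_left_nodes_to_stack(current.borrow().right.clone(), &mut stack);
    }
    None // ran out of nodes: k is larger than the tree
}
```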

Conclusion

In-order traversal is a great pattern to commit to memory, as many different tree problems will require you to apply it. Hopefully the details with Rust’s RefCells are getting more familiar. Next week we’ll do one more problem with binary trees.

If you want to do some deep work with Haskell and binary trees, take a look at our Solve.hs course, where you’ll learn about many different data structures in Haskell, and get the chance to write a balanced binary search tree from scratch!
