An Easy Problem Made Hard: Rust & Binary Trees
In last week’s article, we completed our look at Matrix-based problems. Today, we’re going to start considering another data structure: binary trees.
Binary trees are an extremely important structure in programming. Most notably, balanced binary trees are the underlying structure for ordered sets, which allow logarithmic-time lookups and insertions. A “tree” is made up of “nodes”, where a node can be “null”, or else hold a value. If it holds a value, it also has a “left” child and a “right” child.
If you take our Solve.hs course, you’ll actually learn to implement an auto-balancing ordered tree set from scratch!
But for these next few articles, we’re going to explore some simple problems that involve binary trees. Today we’ll start with a problem that is very simple (rated as “Easy” by LeetCode), but still helps us grasp the core problem solving techniques behind binary trees. We’ll also encounter some interesting curveballs that Rust can throw at us when it comes to building more complex data structures.
The Problem
Our problem today is Invert Binary Tree. Given a binary tree, we want to return a new tree that is the mirror image of the input tree. For example, if we get this tree as an input:
        45
       /  \
     32    50
    /  \     \
   5    40    100
       /  \
     37    43
We should output a tree that looks like this:
        45
       /  \
     50    32
    /     /  \
  100   40    5
       /  \
     43    37
We see that 45 remains the root element, but instead of having 32 on the left and 50 on the right, these two elements are reversed on the next level. Then on the 3rd level, 40 and 5 remain children of 32, but they are also reversed from their prior orientations! This pattern continues all the way down on both sides.
The Algorithm
Binary trees (and in fact, all tree structures) lend themselves very well to recursive algorithms. We’ll use a very simple recursive algorithm here.
If the input node to our function is a “null” node, we will simply return “null” as the output. If the node has a value and children, then we’ll keep that value as the value for our output. However, we will recursively invert both of the child nodes.
Then we’ll take the inverted “left” child node and install it as the “right” child of our result. The inverted “right” child of the original input becomes the “left” child of the resulting node.
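Before turning to specific languages, here is a minimal sketch of these three steps in Rust. It assumes a hypothetical `Node` type whose children are uniquely owned through `Option<Box<Node>>` (this is not the LeetCode type we'll meet later), and it borrows the input tree and builds a brand new mirrored tree, leaving the original untouched.

```rust
// Hypothetical tree type for illustration: each node uniquely owns its children.
#[derive(Debug, PartialEq)]
struct Node {
    val: i32,
    left: Option<Box<Node>>,
    right: Option<Box<Node>>,
}

// Build a mirrored copy of `root` without modifying the input.
fn invert(root: &Option<Box<Node>>) -> Option<Box<Node>> {
    // A null (None) node inverts to a null node.
    let node = root.as_ref()?;
    // Keep the value, recursively invert both children, and swap their positions.
    Some(Box::new(Node {
        val: node.val,
        left: invert(&node.right),
        right: invert(&node.left),
    }))
}

// Helper for building childless nodes.
fn leaf(val: i32) -> Option<Box<Node>> {
    Some(Box::new(Node { val, left: None, right: None }))
}

fn main() {
    // 45 with 32 on the left and 50 on the right.
    let tree = Some(Box::new(Node { val: 45, left: leaf(32), right: leaf(50) }));
    let inverted = invert(&tree);
    let root = inverted.as_ref().unwrap();
    // prints "50 45 32"
    println!(
        "{} {} {}",
        root.left.as_ref().unwrap().val,
        root.val,
        root.right.as_ref().unwrap().val
    );
}
```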
Haskell Solution
Haskell is a natural fit for this problem, since it relies so heavily on recursion. We start by defining a recursive TreeNode type. The canonical way to do this is with a Nil constructor as well as a recursive “value” constructor that actually holds the node’s value and refers to the left and right child. For this problem, we’ll just assume our tree holds Int values, so we won’t parameterize it.
data TreeNode = Nil | Node Int TreeNode TreeNode
  deriving (Show, Eq)
Now solving our problem is easy! We pattern match on the input TreeNode. For our first case, we just return Nil for Nil.
invertTree :: TreeNode -> TreeNode
invertTree Nil = Nil
invertTree (Node x left right) = ...
For our second case, we keep the same Int value for our result. Then we recursively call invertTree on the right child, and put that result in the left position of our new node. Likewise, we recursively invert the left child of the original and use that result as the right child of our new node.
invertTree :: TreeNode -> TreeNode
invertTree Nil = Nil
invertTree (Node x left right) = Node x (invertTree right) (invertTree left)
Very easy!
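For comparison, here is a sketch of how closely Rust can mirror this Haskell code when every node has exactly one owner. The enum `Tree` is a hypothetical translation of the Haskell type, not the LeetCode type used later. `Box` is required because a recursive enum would otherwise have no fixed size, and the function takes the tree by value, consuming the input.

```rust
// Hypothetical Rust counterpart of `data TreeNode = Nil | Node Int TreeNode TreeNode`.
// Box puts each child on the heap so the enum has a known size.
#[derive(Debug, PartialEq)]
enum Tree {
    Nil,
    Node(i32, Box<Tree>, Box<Tree>),
}

// Consumes the input tree and returns its mirror image, clause for clause like the Haskell version.
fn invert_tree(tree: Tree) -> Tree {
    match tree {
        Tree::Nil => Tree::Nil,
        Tree::Node(x, left, right) => {
            // `*right` and `*left` move the subtrees out of their boxes.
            Tree::Node(x, Box::new(invert_tree(*right)), Box::new(invert_tree(*left)))
        }
    }
}

// Helper to build a node without writing Box::new everywhere.
fn node(x: i32, l: Tree, r: Tree) -> Tree {
    Tree::Node(x, Box::new(l), Box::new(r))
}

fn main() {
    let t = node(45, node(32, Tree::Nil, Tree::Nil), node(50, Tree::Nil, Tree::Nil));
    let expected = node(45, node(50, Tree::Nil, Tree::Nil), node(32, Tree::Nil, Tree::Nil));
    assert_eq!(invert_tree(t), expected);
    println!("ok");
}
```

The catch, as we'll see, is that LeetCode's Rust tree does not use single ownership like this, which is where the extra complexity comes from.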
C++ Solution
In a non-functional language, it is still quite possible to solve this problem without recursion, but this is an occasion where we get very nice, clean code with recursion. As a rare treat, we’ll actually start with a C++ solution instead of jumping to Rust right away.
We would start by defining our TreeNode with a struct. Instead of relying on a separate Nil constructor, we refer to nodes through raw pointers, and any of these pointers can be nullptr.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    TreeNode(int v, TreeNode* l, TreeNode* r) : val(v), left(l), right(r) {}
};
TreeNode* invertTree(TreeNode* root) {
    ...
}
And our solution looks almost as easy as the Haskell solution:
TreeNode* invertTree(TreeNode* root) {
    if (root == nullptr) {
        return nullptr;
    }
    return new TreeNode(root->val, invertTree(root->right), invertTree(root->left));
}
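The C++ version allocates a fresh mirrored tree and leaves the input alone. Another common approach, in any language, is to invert the tree in place by swapping each node's two children. As a sketch of that in-place variant in Rust, again assuming a hypothetical `Node` type that owns its children through `Option<Box<Node>>` (not LeetCode's type), a single mutable borrow plus `std::mem::swap` is enough:

```rust
use std::mem;

// Hypothetical tree type: each node uniquely owns its children.
#[derive(Debug)]
struct Node {
    val: i32,
    left: Option<Box<Node>>,
    right: Option<Box<Node>>,
}

// Mirror the tree in place: swap each node's children, then recurse into both.
fn invert_in_place(root: &mut Option<Box<Node>>) {
    if let Some(node) = root {
        mem::swap(&mut node.left, &mut node.right);
        invert_in_place(&mut node.left);
        invert_in_place(&mut node.right);
    }
}

fn leaf(val: i32) -> Option<Box<Node>> {
    Some(Box::new(Node { val, left: None, right: None }))
}

fn main() {
    let mut tree = Some(Box::new(Node { val: 1, left: leaf(2), right: leaf(3) }));
    invert_in_place(&mut tree);
    let root = tree.as_ref().unwrap();
    // prints "3 2"
    println!("{} {}", root.left.as_ref().unwrap().val, root.right.as_ref().unwrap().val);
}
```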
Rust Solution
In Rust, it’s not quite as easy to work with recursive structures because of Rust’s memory system. In C++, we used raw pointers, which are fast but can cause significant problems if you aren’t careful (e.g. dereferencing null pointers, or memory leaks). Haskell uses garbage-collected memory, which adds some runtime overhead but allows us to write simple code that won’t blow up in weird ways like C++ can.
Rust’s Memory System
Rust seeks to be fast like C++, while making it hard to do high-risk things like dereferencing a potentially null pointer, or leaking memory. It does this using the concept of “ownership”, which is tricky to understand at first.
The ownership model makes it a bit harder for us to write basic recursive data structures. To write a binary tree, you’d have to answer questions like:
- Who “owns” the child nodes?
- Can I write a function that accesses the child nodes without taking ownership of them? What if I have to modify them?
- Can I copy a reference to a child node without copying the entire sub-structure?
- Can I create a “new” tree that references part of another tree without copying?
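The last two questions are where reference counting comes in. As a small illustration (a hypothetical, simplified `Node` type with `Rc` children and no `RefCell`, not LeetCode's code), two different parent nodes can share one subtree, and cloning an `Rc` only copies a pointer and bumps a counter rather than copying the subtree:

```rust
use std::rc::Rc;

// Hypothetical immutable tree node whose children are reference-counted.
#[derive(Debug)]
struct Node {
    val: i32,
    left: Option<Rc<Node>>,
    right: Option<Rc<Node>>,
}

fn main() {
    // A subtree we want to appear in two different trees.
    let shared = Rc::new(Node { val: 40, left: None, right: None });

    // Rc::clone copies the pointer and increments the count; the subtree is not copied.
    let tree_a = Node { val: 32, left: None, right: Some(Rc::clone(&shared)) };
    let tree_b = Node { val: 99, left: Some(Rc::clone(&shared)), right: None };

    // One owner in `shared` plus one in each tree: prints "strong count = 3".
    println!("strong count = {}", Rc::strong_count(&shared));

    // Both trees point at the same allocation: prints "same node: true".
    let a_child = tree_a.right.as_ref().unwrap();
    let b_child = tree_b.left.as_ref().unwrap();
    println!("same node: {}", Rc::ptr_eq(a_child, b_child));
    println!("shared value = {}", a_child.val);
}
```

`Rc` on its own only hands out shared, read-only access, so mutating a shared node needs one more wrapper.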
Writing a TreeNode
Here’s the TreeNode struct provided by LeetCode for solving this problem. We can see that references to the nodes themselves are held within 3(!) wrapper types:
use std::cell::RefCell;
use std::rc::Rc;
#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}
impl TreeNode {
    #[inline]
    pub fn new(val: i32) -> Self {
        TreeNode {
            val,
            left: None,
            right: None,
        }
    }
}
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    todo!() // our solution goes here
}
From inside to outside, here’s what the three wrappers mean:
- RefCell provides “interior mutability”: it lets us mutate the TreeNode it holds even through a shared reference to the cell, by checking Rust’s borrowing rules at runtime instead of at compile time.
- Rc is a reference-counting pointer. It tracks how many Rc handles point to the same RefCell, and the cell is deallocated once that count drops to 0.
- Option is Rust’s equivalent of Haskell’s Maybe. It lets us use None for an empty tree.
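Rust normally only permits a single mutable reference, or multiple immutable references, and it checks this at compile time. Rc on its own only gives shared, read-only access to its contents, so RefCell is what lets any holder of an Rc mutate the node. The rule still applies, just at runtime: if a node is already borrowed when we call borrow_mut on it, the program panics. Let’s see how we can use these to write our invert_tree function.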
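Here is a small sketch of what RefCell enforces, using LeetCode's `TreeNode` shape: a mutation made through one `Rc` handle is visible through another, and a second mutable borrow while the first is still alive is refused at runtime. The sketch uses `try_borrow_mut`, which returns an `Err` in that case, where `borrow_mut` would panic.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

fn main() {
    let a = Rc::new(RefCell::new(TreeNode { val: 1, left: None, right: None }));
    let b = Rc::clone(&a); // a second handle to the same cell

    // Mutate through `b`; the temporary RefMut guard is dropped at the end of the statement.
    b.borrow_mut().val = 7;
    println!("a sees {}", a.borrow().val); // prints "a sees 7"

    {
        let _guard = a.borrow_mut(); // first mutable borrow is active
        // A second mutable borrow is rejected at runtime: prints "false".
        println!("second borrow ok? {}", b.try_borrow_mut().is_ok());
    } // _guard is dropped here, releasing the borrow

    // prints "true"
    println!("after drop ok? {}", b.try_borrow_mut().is_ok());
}
```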
Solving the Problem
We start by “cloning” the root input. For many types, clone makes a deep copy, but here it doesn’t copy the tree at all! Because the node is wrapped in Rc, cloning the Option just creates another Rc pointer to the same RefCell and increments the reference count. We then use if let to check whether this is a Some value. If it is None, we just return root.
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    if let Some(node) = root.clone() {
        ...
    }
    return root;
}
If we didn’t clone root, the pattern match would move the Rc out of root, and the compiler would reject our later attempt to return root as a use of a moved value.
Next, we use borrow_mut to get mutable access to the TreeNode inside the RefCell. The resulting node_ref is a RefMut<TreeNode>, a guard that dereferences to TreeNode, so we can finally work with the individual fields.
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    if let Some(node) = root.clone() {
        let mut node_ref = node.borrow_mut();
        ...
    }
    return root;
}
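Through node_ref, both left and right have the full wrapper type Option<Rc<RefCell<TreeNode>>>. We want to recursively call invert_tree on these. Once again, though, we have to call clone before passing them to the recursive function: invert_tree takes its argument by value, and we can’t move a field out of a node we have only borrowed. As before, the clone only copies an Rc pointer, not the subtree.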
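Cloning isn't the only way to satisfy the compiler here. As a sketch of one alternative (a variant, not the solution this article builds), we can borrow root with `if let Some(node) = &root` so nothing is moved out of it, and use `Option::take` to move each child out of the node, leaving `None` behind. Since both fields are overwritten right afterward, the end result is the same. The `leaf` helper is a hypothetical addition for the demo.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

// Variant: borrow `root` and `take` the children instead of cloning.
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    // Matching on `&root` borrows it, so `root` is still ours to return below.
    if let Some(node) = &root {
        let mut node_ref = node.borrow_mut();
        // `take` moves each child out of the node and leaves `None` in its place.
        let left = invert_tree(node_ref.left.take());
        let right = invert_tree(node_ref.right.take());
        // Reinstall the inverted children in swapped positions.
        node_ref.left = right;
        node_ref.right = left;
    }
    root
}

// Hypothetical helper for building childless nodes.
fn leaf(val: i32) -> Option<Rc<RefCell<TreeNode>>> {
    Some(Rc::new(RefCell::new(TreeNode { val, left: None, right: None })))
}

fn main() {
    let root = Some(Rc::new(RefCell::new(TreeNode { val: 45, left: leaf(32), right: leaf(50) })));
    let inverted = invert_tree(root).unwrap();
    let n = inverted.borrow();
    // prints "left = 50, right = 32"
    println!(
        "left = {}, right = {}",
        n.left.as_ref().unwrap().borrow().val,
        n.right.as_ref().unwrap().borrow().val
    );
}
```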
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    if let Some(node) = root.clone() {
        let mut node_ref = node.borrow_mut();
        // Recursively invert left and right subtrees
        let left = invert_tree(node_ref.left.clone());
        let right = invert_tree(node_ref.right.clone());
        ...
    }
    return root;
}
Now, because node_ref gives us mutable access, we can install these new results as its left and right subtrees, in swapped positions!
pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    if let Some(node) = root.clone() {
        let mut node_ref = node.borrow_mut();
        // Recursively invert left and right subtrees
        let left = invert_tree(node_ref.left.clone());
        let right = invert_tree(node_ref.right.clone());
        // Swap them
        node_ref.left = right;
        node_ref.right = left;
    }
    return root;
}
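And now we’re done! We don’t need a separate return statement inside the if. We modified the node through node_ref, which points into the same RefCell that root holds, so returning root returns our modified tree. Note that, unlike the Haskell and C++ versions, this solution mutates the input tree in place instead of building a new one.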
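To try the finished function locally, here is a self-contained sketch. The `TreeNode` struct and `invert_tree` are the LeetCode type and the solution from this article; the `node` and `inorder` helpers are hypothetical additions for testing. It builds the example tree from the start of the article, inverts it, and collects an in-order traversal. Because the input is a binary search tree, the mirrored tree's in-order traversal should come out in descending order.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq)]
pub struct TreeNode {
    pub val: i32,
    pub left: Option<Rc<RefCell<TreeNode>>>,
    pub right: Option<Rc<RefCell<TreeNode>>>,
}

pub fn invert_tree(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
    if let Some(node) = root.clone() {
        let mut node_ref = node.borrow_mut();
        // Recursively invert left and right subtrees
        let left = invert_tree(node_ref.left.clone());
        let right = invert_tree(node_ref.right.clone());
        // Swap them
        node_ref.left = right;
        node_ref.right = left;
    }
    return root;
}

// Hypothetical helper: build a node with the given children.
fn node(
    val: i32,
    left: Option<Rc<RefCell<TreeNode>>>,
    right: Option<Rc<RefCell<TreeNode>>>,
) -> Option<Rc<RefCell<TreeNode>>> {
    Some(Rc::new(RefCell::new(TreeNode { val, left, right })))
}

// Hypothetical helper: collect values in left, root, right order.
fn inorder(root: &Option<Rc<RefCell<TreeNode>>>, out: &mut Vec<i32>) {
    if let Some(n) = root {
        let n = n.borrow();
        inorder(&n.left, out);
        out.push(n.val);
        inorder(&n.right, out);
    }
}

fn main() {
    // The example tree from the start of the article.
    let root = node(
        45,
        node(32, node(5, None, None), node(40, node(37, None, None), node(43, None, None))),
        node(50, None, node(100, None, None)),
    );
    let inverted = invert_tree(root);
    let mut values = Vec::new();
    inorder(&inverted, &mut values);
    // prints "[100, 50, 45, 43, 40, 37, 32, 5]"
    println!("{:?}", values);
}
```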
Conclusion
Even though this was a simple problem with a basic recursive algorithm, we saw how Rust presented some interesting difficulties in applying this algorithm. Languages all make different tradeoffs, so every language has some example where it is difficult to write code that is simple in other languages. For Rust, this is recursive data structures. For Haskell though, it’s things like mutable arrays.
If you want to get some serious practice with binary trees, you should sign up for our problem solving course, Solve.hs. In Module 2, you’ll actually get to implement a balanced tree set from scratch, which is a very interesting and challenging problem that will stretch your knowledge!