Binary Search in Haskell and Rust
This week we’ll be continuing our series of problem solving in Haskell and Rust. But now we’re going to start moving beyond the terrain of “basic” problem solving techniques with strings, lists and arrays, and start moving in the direction of more complicated data structures and algorithms. Today we’ll explore a problem that is still array-based, but uses a tricky algorithm that involves binary search!
You’ll learn more about Data Structures and Algorithms in our Solve.hs course! The last 7 weeks or so of blog articles have focused on the types of problems you’ll see in Module 1 of that course, but now we’re going to start encountering ideas from Modules 2 & 3, which look extensively at essential data structures and algorithms you need to know for problem solving.
The Problem
Today’s problem is median of two sorted arrays. In this problem, we receive two arrays of numbers as input, each of them in sorted order. The arrays are not necessarily of the same size. Our job is to find the median of the cumulative set of numbers.
Now there’s a conceptually easy approach to this. We could simply scan through the two arrays, keeping track of one index for each one. We would increase the index for whichever number is currently smaller, and stop once we have passed by half of the total numbers. This approach is essentially the “merge” part of merge sort, and it would take O(n) time, since we are scanning half of all the numbers.
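To make the simple approach concrete, here is a minimal Rust sketch of the merge-style scan. The name median_by_merge is our own for illustration, and the sketch assumes at least one of the two inputs is non-empty.

```rust
/// Median of two sorted slices via a merge-style scan, O(n + m) time.
/// Hypothetical helper name; assumes at least one element overall.
fn median_by_merge(a: &[i32], b: &[i32]) -> f64 {
    let t = a.len() + b.len();
    assert!(t > 0, "need at least one element");
    let (mut i, mut j) = (0, 0);
    // `cur` is the most recently consumed element, `prev` the one before it.
    let (mut prev, mut cur) = (0, 0);
    // Consume merged elements up to and including merged index t / 2.
    for _ in 0..=(t / 2) {
        prev = cur;
        if j >= b.len() || (i < a.len() && a[i] <= b[j]) {
            cur = a[i];
            i += 1;
        } else {
            cur = b[j];
            j += 1;
        }
    }
    if t % 2 == 1 {
        f64::from(cur)
    } else {
        (f64::from(prev) + f64::from(cur)) / 2.0
    }
}
```

A scan like this is also handy later as a reference to check a faster solution against.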
However, there’s a faster approach! And if you are asked this question in an interview for anything other than a very junior position, your interviewer will expect you to find this faster approach. Because the arrays are sorted, we can leverage binary search to find the median in O(log n) time. The approach isn’t easy to see though! Let’s go over the algorithm before we get into any code.
The Algorithm
This algorithm is a little tricky to follow (this problem is rated as “hard” on LeetCode). So we’re going to treat this a bit like a mathematical proof, and begin by defining useful terms. Then it will be easy to describe the coding concepts behind the algorithm.
Defining our Terms
Our input consists of 2 arrays, arr1 and arr2, with potentially different sizes n and m, respectively. Without loss of generality, let arr1 be the “shorter” array, so that n <= m. We’ll also define t as the total number of elements, n + m.
It is worthwhile to note right off the bat that if t is odd, then a single element from one of the two lists will be the median. If t is even, then we will average two elements together. Even though we won’t actually create the final merged array, we can imagine that it consists of 3 parts:
- The “prior” portion - all numbers before the median element(s)
- The median element(s), either 1 or 2.
- The “latter” portion - all numbers after the median element(s)
The total number of elements in the “prior” portion will end up being (t - 1) / 2, bearing in mind how integer division works. For example, whether t is 15 or 16, we get 7 elements in the “prior” portion. We’ll use p for this number.
Finally, let’s imagine p1, the number of elements from arr1 that will end up in the prior portion. If we know p1, then p2, the number of elements from arr2 in the prior portion, is fixed, because p1 + p2 = p. We can then think of p1 as an index into arr1, the index of the first element that is not in the prior portion. The only trick is that this index could be n, indicating that all elements of arr1 are in the prior portion.
Getting the Final Answer from our Terms
If we have the “correct” values for p1 and p2, then finding the median is easy. If t is odd, then the lower number between arr1[p1] and arr2[p2] is the median. If t is even, then we average the two smallest numbers among (arr1[p1], arr2[p2], arr1[p1 + 1], arr2[p2 + 1]). In both cases we skip any index that falls past the end of its array.
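As a small illustration of this step, here is a Rust sketch that reads the median off a split that is already known to be correct. The function name median_from_split is hypothetical, and the sketch assumes the caller passes a valid (p1, p2) pair with p1 + p2 = (t - 1) / 2.

```rust
/// Reads the median off a known-correct split (p1, p2).
/// Hypothetical helper; assumes the split is valid and t >= 1.
fn median_from_split(arr1: &[i32], arr2: &[i32], p1: usize, p2: usize) -> f64 {
    let t = arr1.len() + arr2.len();
    // Collect up to four candidates; `get` returns None past the end of a slice.
    let mut candidates: Vec<i32> = Vec::new();
    for (arr, idx) in [(arr1, p1), (arr1, p1 + 1), (arr2, p2), (arr2, p2 + 1)] {
        if let Some(&value) = arr.get(idx) {
            candidates.push(value);
        }
    }
    candidates.sort();
    if t % 2 == 1 {
        f64::from(candidates[0])
    } else {
        (f64::from(candidates[0]) + f64::from(candidates[1])) / 2.0
    }
}
```

For arr1 = [1, 2] and arr2 = [3, 4] we have p = 1, and the correct split is p1 = 1, p2 = 0, so the candidates are 2, 3 and 4 and the median is the average of 2 and 3.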
So we’ve reduced this problem to a matter of finding p1, since p2 can be easily derived from it. How do we know we have the “correct” value for p1, and how do we search for it efficiently?
Solving for p1
The answer is that we will conduct a binary search on arr1 in order to find the correct value of p1. For any particular choice of p1, we determine the corresponding value of p2. Then we make two comparisons:
- Compare arr1[p1 - 1] to arr2[p2]
- Compare arr2[p2 - 1] to arr1[p1]
If both comparisons are less-than-or-equals, then our two p values are correct! The slices arr1[0..p1-1] and arr2[0..p2-1] always constitute a total of p values, and if these values are no larger than arr1[p1] and arr2[p2], then they constitute the entire “prior” set.
If, on the other hand, the first comparison yields “greater than”, then we have too many values from arr1 in our prior set. This means we need to recursively do the binary search on the left side of arr1, since p1 should be smaller.
Then if the second comparison yields “greater than”, we have too few values from arr1 in the “prior” set. We should increase p1 by searching the right half of our array.
This provides a complete algorithm for us to follow!
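The two comparisons can be sketched as a single classification function in Rust. The name classify_split is our own, and the sketch assumes arr1 is the shorter slice, arr2 is non-empty, and p1 is at most both arr1.len() and p.

```rust
use std::cmp::Ordering;

/// Classifies a candidate p1 for the shorter slice arr1.
/// Equal: the split is valid. Less: p1 must shrink. Greater: p1 must grow.
/// Hypothetical helper; assumes p1 <= arr1.len() <= arr2.len() and p1 <= p.
fn classify_split(arr1: &[i32], arr2: &[i32], p1: usize, p: usize) -> Ordering {
    let p2 = p - p1;
    // First comparison: the last element taken from arr1 vs. the next one in arr2.
    let first_ok = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
    // Second comparison: the last element taken from arr2 vs. the next one in arr1.
    let second_ok = p1 == arr1.len() || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
    if !first_ok {
        Ordering::Less
    } else if !second_ok {
        Ordering::Greater
    } else {
        Ordering::Equal
    }
}
```

With arr1 = [1, 2], arr2 = [3, 4] and p = 1, the candidate p1 = 0 is classified as Greater (we need more of arr1), while p1 = 1 is Equal.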
Rust Implementation
Our algorithm description was quite long, but the advantage of having so many details is that the code starts to write itself! We’ll start with our Rust implementation. Stage 1 is to define all of the terms using our input values. We want to define our sizes and array references generically so that arr1 is the shorter array:
pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
    let mut n = nums1.len();
    let mut m = nums2.len();
    let mut arr1: &Vec<i32> = &nums1;
    let mut arr2: &Vec<i32> = &nums2;
    if m < n {
        n = nums2.len();
        m = nums1.len();
        arr1 = &nums2;
        arr2 = &nums1;
    }
    let t = n + m;
    let p: usize = (t - 1) / 2;
    ...
}
Anatomy of a Binary Search
The next stage is the binary search, so we can find p1 and p2. Now a binary search is a particular kind of loop pattern. Like many of the loop patterns we worked with in the previous weeks, we can express it recursively, or with a loop construct like for or while. We’ll start with a while loop solution for Rust, and then show the recursive solution with Haskell.
All loops maintain some kind of state. For a binary search, the primary state is the two endpoints representing our “interval of interest”. This starts out as the entire interval, and shrinks by half each time until we’ve narrowed to a single element (or no elements). We’ll represent these interval endpoints with low and hi. Our loop concludes once low is as large as hi.
let mut low = 0;
// Use the shorter array size!
let mut hi = n;
while low < hi {
    ...
}
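If the low/hi pattern is new to you, it may help to see it in its most familiar setting first. The sketch below is a standalone example, separate from our median problem: it finds the first index in a sorted slice whose element is at least target. The name lower_bound is our own, and note that this variant shrinks the interval with hi = mid rather than hi = mid - 1, since mid itself may still be the answer here.

```rust
/// Returns the first index whose element is >= target, or the slice length
/// if no such element exists. Assumes `sorted` is in ascending order.
fn lower_bound(sorted: &[i32], target: i32) -> usize {
    // The answer always lies in the interval [low, hi].
    let (mut low, mut hi) = (0, sorted.len());
    while low < hi {
        let mid = low + (hi - low) / 2;
        if sorted[mid] < target {
            // Everything up to and including mid is too small.
            low = mid + 1;
        } else {
            // mid could be the answer, so keep it in the interval.
            hi = mid;
        }
    }
    low
}
```

Each iteration discards about half of the interval, which is where the logarithmic running time comes from.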
In our particular case, we are also trying to determine the values for p1 and p2. Each time we specify an interval, we’ll see if the midpoint of that interval (between low and hi) is the correct value of p1:
...
let mut low = 0;
let mut hi = n;
let mut p1 = 0;
let mut p2 = 0;
while low < hi {
    p1 = (low + hi) / 2;
    p2 = p - p1;
    ...
}
Now we evaluate this p1 value using the two conditions we specified in our algorithm. These are self-explanatory, except we do need to cover some edge cases where one of our values is at the edge of the array bounds.
For example, if p1 is 0, the first condition is always “true”. If this condition is negated, this means we want fewer elements from arr1, but this is impossible if p1 is 0.
...
let mut low = 0;
let mut hi = n;
let mut p1 = 0;
let mut p2 = 0;
while low < hi {
    p1 = (low + hi) / 2;
    p2 = p - p1;
    let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
    let cond2 = p1 == n || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
    if cond1 && cond2 {
        break;
    } else if !cond1 {
        p1 -= 1;
        hi = p1;
    } else {
        p1 += 1;
        low = p1;
    }
}
p2 = p - p1;
...
If both conditions are met, you’ll see we break, because we’ve found the right value for p1! Otherwise, we know p1 is invalid. This means we want to exclude the existing p1 value from further consideration by changing either low or hi to remove it from the interval of interest.
So if cond1 is false, hi becomes p1 - 1, and if cond2 is false, low becomes p1 + 1. In both cases, we also modify p1 itself first so that our loop does not conclude with p1 in an invalid location.
Getting the Final Answer
Now that we have p1 and p2, we have to do a couple final tricks to get the final answer. We want to get the first “smaller” value between arr1[p1] and arr2[p2]. But we have to handle the edge case where p1 might be n, AND we want to increment the index for the array we take. Note that p2 cannot be out of bounds right now, since p2 is at most p, and p is at most m - 1!
let mut median = arr2[p2];
if p1 < n && arr1[p1] < arr2[p2] {
    median = arr1[p1];
    p1 += 1;
} else {
    p2 += 1;
}
If the total number of elements is odd, we can simply return this number (converting to a float). However, in the even case we need one more number to take an average. So we’ll compare the values at the indices again, but now accounting for the fact that either index (but not both) could be out of bounds.
let mut median = arr2[p2];
if p1 < n && arr1[p1] < arr2[p2] {
    median = arr1[p1];
    p1 += 1;
} else {
    p2 += 1;
}
if t % 2 == 0 {
    // Even case: add the next-smallest element, then halve.
    if p1 >= n {
        median += arr2[p2];
    } else if p2 >= m {
        median += arr1[p1];
    } else {
        median += cmp::min(arr1[p1], arr2[p2]);
    }
    let median_f: f64 = median.into();
    return median_f / 2.0;
} else {
    return median.into();
}
Here’s the complete implementation:
use std::cmp;

pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
    // Stage 1: define our terms so that arr1 is the shorter array.
    let mut n = nums1.len();
    let mut m = nums2.len();
    let mut arr1: &Vec<i32> = &nums1;
    let mut arr2: &Vec<i32> = &nums2;
    if m < n {
        n = nums2.len();
        m = nums1.len();
        arr1 = &nums2;
        arr2 = &nums1;
    }
    let t = n + m;
    let p: usize = (t - 1) / 2;

    // Stage 2: binary search for p1 over the shorter array.
    let mut low = 0;
    let mut hi = n;
    let mut p1 = 0;
    let mut p2 = 0;
    while low < hi {
        p1 = (low + hi) / 2;
        p2 = p - p1;
        let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
        let cond2 = p1 == n || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
        if cond1 && cond2 {
            break;
        } else if !cond1 {
            p1 -= 1;
            hi = p1;
        } else {
            p1 += 1;
            low = p1;
        }
    }
    p2 = p - p1;

    // Stage 3: take the next one or two elements after the prior portion.
    let mut median = arr2[p2];
    if p1 < n && arr1[p1] < arr2[p2] {
        median = arr1[p1];
        p1 += 1;
    } else {
        p2 += 1;
    }
    if t % 2 == 0 {
        if p1 >= n {
            median += arr2[p2];
        } else if p2 >= m {
            median += arr1[p1];
        } else {
            median += cmp::min(arr1[p1], arr2[p2]);
        }
        let median_f: f64 = median.into();
        return median_f / 2.0;
    } else {
        return median.into();
    }
}
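With this many edge cases, it is worth checking a binary-search solution against a brute-force reference. The sketch below is one way to do that, under a few assumptions of our own: it uses a compact, slice-based variant of the algorithm rather than the exact function above, the names median_by_sorting and median_by_partition are hypothetical, and both functions assume at least one element overall.

```rust
/// Brute-force reference: concatenate, sort, and read off the middle.
fn median_by_sorting(a: &[i32], b: &[i32]) -> f64 {
    let mut all: Vec<i32> = a.iter().chain(b.iter()).copied().collect();
    all.sort();
    let t = all.len();
    assert!(t > 0, "need at least one element");
    if t % 2 == 1 {
        f64::from(all[t / 2])
    } else {
        (f64::from(all[t / 2 - 1]) + f64::from(all[t / 2])) / 2.0
    }
}

/// Slice-based sketch of the binary-search approach (a variant, not the
/// exact function from the article).
fn median_by_partition(a: &[i32], b: &[i32]) -> f64 {
    // Make arr1 the shorter slice so the search interval is as small as possible.
    let (arr1, arr2) = if a.len() <= b.len() { (a, b) } else { (b, a) };
    let n = arr1.len();
    let t = n + arr2.len();
    assert!(t > 0, "need at least one element");
    let p = (t - 1) / 2;

    // Binary search for p1 in [low, hi]; the interval always holds a valid split.
    let (mut low, mut hi) = (0, n);
    while low < hi {
        let p1 = (low + hi) / 2;
        let p2 = p - p1;
        if p1 > 0 && arr1[p1 - 1] > arr2[p2] {
            hi = p1 - 1; // too many elements taken from arr1
        } else if p2 > 0 && arr2[p2 - 1] > arr1[p1] {
            low = p1 + 1; // too few elements taken from arr1
        } else {
            low = p1; // valid split found
            break;
        }
    }
    let (p1, p2) = (low, p - low);

    // The next one or two merged elements are among these four positions.
    let mut next: Vec<i32> = Vec::new();
    for (arr, idx) in [(arr1, p1), (arr1, p1 + 1), (arr2, p2), (arr2, p2 + 1)] {
        if let Some(&value) = arr.get(idx) {
            next.push(value);
        }
    }
    next.sort();
    if t % 2 == 1 {
        f64::from(next[0])
    } else {
        (f64::from(next[0]) + f64::from(next[1])) / 2.0
    }
}
```

Comparing the two on a handful of small inputs (one empty array, duplicates, all of arr1 below or above arr2) exercises most of the boundary conditions discussed above.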
Haskell Implementation
Now let’s examine the Haskell implementation. Unlike the LeetCode version, we’ll just assume our inputs are Double already instead of doing a conversion. Once again, we start by defining the terms:
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
  where
    n' = V.length input1
    m' = V.length input2
    t = n' + m'
    p = (t - 1) `quot` 2
    (n, m, arr1, arr2) = if V.length input1 <= V.length input2
      then (n', m', input1, input2) else (m', n', input2, input1)
    ...
Now we’ll implement the binary search, this time doing a recursive function. We’ll do this in two parts, starting with a helper function. This helper function will simply tell us if a particular index is correct for p1. The trick though is that we’ll return an Ordering instead of just a Bool:
-- data Ordering = LT | EQ | GT
f :: Int -> Ordering
This lets us signal 3 possibilities. If we return EQ, this means the index is valid. If we return LT, this will mean we want fewer values from arr1. And then GT means we want more values from arr1.
With this framing it’s easy to see the implementation of this helper now. We determine the appropriate p2, figure out our two conditions, and return the value for each condition:
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
  where
    ...
    f :: Int -> Ordering
    f pi1 =
      let pi2 = p - pi1
          cond1 = pi1 == 0 || arr1 V.! (pi1 - 1) <= arr2 V.! pi2
          cond2 = pi1 == n || pi2 == 0 || (arr2 V.! (pi2 - 1) <= arr1 V.! pi1)
      in if cond1 && cond2 then EQ else if (not cond1) then LT else GT
Now we can use this helper in a recursive binary search. The binary search tracks two pieces of state for our interval ((Int, Int)), and it will return the correct value for p1. The implementation applies the base case (return low if low >= hi), determines the midpoint, calls our helper, and then recurses appropriately based on the helper result.
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
  where
    ...
    f :: Int -> Ordering
    f pi1 = ...
    search :: (Int, Int) -> Int
    search (low, hi) = if low >= hi then low else
      let mid = (low + hi) `quot` 2
      in case f mid of
        EQ -> mid
        LT -> search (low, mid - 1)
        GT -> search (mid + 1, hi)
    p1 = search (0, n)
    p2 = p - p1
    ...
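For readers following along on the Rust side, the same recursive shape can be sketched in Rust as well. This is our own illustration rather than part of the solution above: the function name search and the classify parameter are hypothetical, and classify is assumed to behave like the helper described earlier (Equal for a valid index, Less to move left, Greater to move right).

```rust
use std::cmp::Ordering;

/// Recursive binary search over [low, hi], driven by a three-way classifier.
/// Assumes some index in [low, hi] is classified as Equal, or is reached
/// when the interval narrows to a single point.
fn search(low: usize, hi: usize, classify: &dyn Fn(usize) -> Ordering) -> usize {
    // Base case: the interval has narrowed to a single candidate.
    if low >= hi {
        return low;
    }
    let mid = (low + hi) / 2;
    match classify(mid) {
        Ordering::Equal => mid,
        // saturating_sub guards against usize underflow when mid is 0,
        // which a signed Int in Haskell would not need.
        Ordering::Less => search(low, mid.saturating_sub(1), classify),
        Ordering::Greater => search(mid + 1, hi, classify),
    }
}
```

The loop and the recursion carry the same state, the pair (low, hi); the recursive version simply passes the new interval as arguments instead of mutating variables.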
For the final part of the problem, we’ll define a helper. Given p1 and p2, it will emit the “lower” value between the two indices in the array (accounting for edge cases) as well as the two new indices (since one will increment).
This is a matter of lazily defining the “next” value for each array, the “end” condition of each array, and the “result” if that array’s value is chosen:
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
  where
    ...
    findNext pi1 pi2 =
      let next1 = arr1 V.! pi1
          next2 = arr2 V.! pi2
          end1 = pi1 >= n
          end2 = pi2 >= m
          res1 = (next1, pi1 + 1, pi2)
          res2 = (next2, pi1, pi2 + 1)
      in if end1 then res2
           else if end2 then res1
           else if next1 <= next2 then res1 else res2
Now we just apply this either once or twice to get our result!
medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = result
  where
    ...
    tIsEven = even t
    (median1, nextP1, nextP2) = findNext p1 p2
    (median2, _, _) = findNext nextP1 nextP2
    result = if tIsEven
      then (median1 + median2) / 2.0
      else median1
Here’s the complete implementation:
import qualified Data.Vector as V

medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = result
  where
    n' = V.length input1
    m' = V.length input2
    t = n' + m'
    p = (t - 1) `quot` 2
    (n, m, arr1, arr2) = if V.length input1 <= V.length input2
      then (n', m', input1, input2) else (m', n', input2, input1)

    -- Evaluate the index in arr1
    -- If this does indicate the index can be part of a median, return EQ
    -- If it indicates we need to move left in arr1, return LT
    -- If it indicates we need to move right in arr1, return GT
    -- Precondition: pi1 <= n
    f :: Int -> Ordering
    f pi1 =
      let pi2 = p - pi1
          cond1 = pi1 == 0 || arr1 V.! (pi1 - 1) <= arr2 V.! pi2
          cond2 = pi1 == n || pi2 == 0 || (arr2 V.! (pi2 - 1) <= arr1 V.! pi1)
      in if cond1 && cond2 then EQ else if (not cond1) then LT else GT

    search :: (Int, Int) -> Int
    search (low, hi) = if low >= hi then low else
      let mid = (low + hi) `quot` 2
      in case f mid of
        EQ -> mid
        LT -> search (low, mid - 1)
        GT -> search (mid + 1, hi)

    -- Laziness means next1 and next2 are only indexed when actually used
    findNext pi1 pi2 =
      let next1 = arr1 V.! pi1
          next2 = arr2 V.! pi2
          end1 = pi1 >= n
          end2 = pi2 >= m
          res1 = (next1, pi1 + 1, pi2)
          res2 = (next2, pi1, pi2 + 1)
      in if end1 then res2
           else if end2 then res1
           else if next1 <= next2 then res1 else res2

    p1 = search (0, n)
    p2 = p - p1
    tIsEven = even t
    (median1, nextP1, nextP2) = findNext p1 p2
    (median2, _, _) = findNext nextP1 nextP2
    result = if tIsEven
      then (median1 + median2) / 2.0
      else median1
Conclusion
If you want to learn more about these kinds of problem solving techniques, you should take our course Solve.hs! In the coming weeks, we’ll see more problems related to data structures and algorithms, which are covered extensively in Modules 2 and 3 of that course!