Binary Search in Haskell and Rust

This week we’re continuing our series on problem solving in Haskell and Rust. But now we’re going to start moving beyond the terrain of “basic” problem-solving techniques with strings, lists, and arrays, and head in the direction of more complicated data structures and algorithms. Today we’ll explore a problem that is still array-based, but uses a tricky algorithm that involves binary search!

You’ll learn more about Data Structures and Algorithms in our Solve.hs course! The last 7 weeks or so of blog articles have focused on the types of problems you’ll see in Module 1 of that course, but now we’re going to start encountering ideas from Modules 2 & 3, which look extensively at essential data structures and algorithms you need to know for problem solving.

The Problem

Today’s problem is Median of Two Sorted Arrays. In this problem, we receive two arrays of numbers as input, each of them in sorted order. The arrays are not necessarily the same size. Our job is to find the median of the combined set of numbers.

Now, there’s a conceptually easy approach to this. We could simply scan through the two arrays, keeping one index for each. At each step we advance the index for whichever number is currently smaller, and we stop once we have passed half of the total numbers. This approach is essentially the “merge” step of merge sort, and it takes linear time, O(n + m) for arrays of sizes n and m, since we scan half of all the numbers.
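A minimal Rust sketch of this merge-style scan follows. The name `median_by_merge` is our own, and it assumes at least one of the arrays is non-empty. Besides illustrating the linear approach, it makes a convenient reference answer when testing the faster algorithm.

```rust
/// Linear-time baseline: walk both sorted arrays like the merge step of
/// merge sort, stopping once the median element(s) have been seen.
/// Assumes at least one array is non-empty.
pub fn median_by_merge(arr1: &[i32], arr2: &[i32]) -> f64 {
    let t = arr1.len() + arr2.len();
    let (mut i, mut j) = (0, 0);
    // The most recently consumed value and the one consumed before it.
    let mut prev = 0;
    let mut curr = 0;
    // Consume merged elements up to and including merged index t / 2.
    for _ in 0..=t / 2 {
        prev = curr;
        // Take from arr1 if arr2 is exhausted or arr1's head is no larger.
        if j >= arr2.len() || (i < arr1.len() && arr1[i] <= arr2[j]) {
            curr = arr1[i];
            i += 1;
        } else {
            curr = arr2[j];
            j += 1;
        }
    }
    if t % 2 == 1 {
        f64::from(curr)
    } else {
        (f64::from(prev) + f64::from(curr)) / 2.0
    }
}
```

For example, `median_by_merge(&[1, 3], &[2])` consumes 1 and then 2, and returns `2.0`.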

However, there’s a faster approach! And if you are asked this question in an interview for anything other than a very junior position, your interviewer will expect you to find this faster approach. Because the arrays are sorted, we can leverage binary search over the shorter array to find the median in O(log(min(n, m))) time. The approach isn’t easy to see, though! Let’s go over the algorithm before we get into any code.

The Algorithm

This algorithm is a little tricky to follow (this problem is rated as “hard” on LeetCode). So we’re going to treat this a bit like a mathematical proof, and begin by defining useful terms. Then it will be easy to describe the coding concepts behind the algorithm.

Defining our Terms

Our input consists of 2 arrays, arr1 and arr2, with potentially different sizes n and m, respectively. Without loss of generality, let arr1 be the “shorter” array, so that n <= m. We’ll also define t as the total number of elements, n + m.

It is worth noting right off the bat that if t is odd, then a single element from one of the two arrays will be the median. If t is even, then we will average two elements together. Even though we won’t actually create the final merged array, we can imagine that it consists of 3 parts:

  1. The “prior” portion - all numbers before the median element(s)
  2. The median element(s), either 1 or 2.
  3. The “latter” portion - all numbers after the median element(s)

The total number of elements in the “prior” portion will be (t - 1) / 2, using integer division (which rounds down). For example, whether t is 15 or 16, we get 7 elements in the “prior” portion. We’ll use p for this number.
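As a quick check of this arithmetic, here is a small Rust sketch. The helpers `prior_len` and `median_of_merged` are hypothetical and work on an already-merged array purely for illustration; they show that the median always begins right after the p prior elements, at index p.

```rust
/// Number of elements in the "prior" portion: (t - 1) / 2 with integer division.
/// Assumes t >= 1, since `t - 1` would underflow for an empty input.
pub fn prior_len(t: usize) -> usize {
    (t - 1) / 2
}

/// Median of a slice that is already sorted (the merged array we never build).
/// The median starts at index p: it is merged[p] when t is odd, and the
/// average of merged[p] and merged[p + 1] when t is even.
pub fn median_of_merged(merged: &[i32]) -> f64 {
    let t = merged.len();
    let p = prior_len(t);
    if t % 2 == 1 {
        f64::from(merged[p])
    } else {
        (f64::from(merged[p]) + f64::from(merged[p + 1])) / 2.0
    }
}
```

Both `prior_len(15)` and `prior_len(16)` evaluate to 7, matching the example above.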

Finally, let’s imagine p1, the number of elements from arr1 that will end up in the prior portion. If we know p1, then p2, the number of elements from arr2 in the prior portion, is fixed, because p1 + p2 = p. We can then think of p1 as an index into arr1: the index of the first element that is not in the prior portion. The only trick is that this index could be n, indicating that all elements of arr1 are in the prior portion.

Getting the Final Answer from our Terms

If we have the “correct” values for p1 and p2, then finding the median is easy. If t is odd, then the lower number between arr1[p1] and arr2[p2] is the median. If t is even, then we average the two smallest numbers among (arr1[p1], arr2[p2], arr1[p1 + 1], arr2[p2 + 1]), skipping any index that falls past the end of its array.
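One compact way to express this final step is sketched below in Rust, assuming p1 and p2 are already correct: gather the up-to-four candidates with `slice::get` (which returns `None` for an index past the end) and take the smallest one or two. The name `median_from_split` is our own; the implementations later in this article handle the bounds with explicit checks instead.

```rust
/// Median from an already-correct split, where p1 + p2 = p elements
/// sit in the "prior" portion. `slice::get` returns None past the end
/// of a slice, so out-of-bounds candidates are simply skipped.
pub fn median_from_split(arr1: &[i32], arr2: &[i32], p1: usize, p2: usize) -> f64 {
    // The (up to) four candidates named in the text.
    let mut candidates: Vec<i32> = vec![
        arr1.get(p1),
        arr2.get(p2),
        arr1.get(p1 + 1),
        arr2.get(p2 + 1),
    ]
    .into_iter()
    .flatten()
    .copied()
    .collect();
    candidates.sort_unstable();
    let t = arr1.len() + arr2.len();
    if t % 2 == 1 {
        // Odd total: the smallest candidate is the median.
        f64::from(candidates[0])
    } else {
        // Even total: average the two smallest candidates.
        (f64::from(candidates[0]) + f64::from(candidates[1])) / 2.0
    }
}
```

For arr1 = [1, 3] and arr2 = [2, 4], the correct split is p1 = 1, p2 = 0, and the candidates 3, 2, 4 give (2 + 3) / 2 = 2.5.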

So we’ve reduced this problem to a matter of finding p1, since p2 can be easily derived from it. How do we know we have the “correct” value for p1, and how do we search for it efficiently?

Solving for p1

The answer is that we will conduct a binary search on arr1 in order to find the correct value of p1. For any particular choice of p1, we determine the corresponding value of p2. Then we make two comparisons:

  1. Compare arr1[p1 - 1] to arr2[p2]
  2. Compare arr2[p2 - 1] to arr1[p1]

If both comparisons yield “less than or equal”, then our two p values are correct! The first p1 elements of arr1 and the first p2 elements of arr2 always make up a total of p values, and if none of them is larger than arr1[p1] or arr2[p2], then together they form the entire “prior” portion.

If, on the other hand, the first comparison yields “greater than”, then we have too many values from arr1 in our prior set. This means we need to continue the binary search on the left half of arr1, since p1 should be smaller.

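If instead the second comparison yields “greater than”, we have too few values from arr1 in the “prior” set, so we should increase p1 by searching the right half. (The two comparisons can never fail at the same time: that would require arr1[p1 - 1] > arr2[p2] >= arr2[p2 - 1] > arr1[p1], which contradicts arr1 being sorted.)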

This provides a complete algorithm for us to follow!
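Before making this fast, it can help to see the two comparisons as a plain predicate. The Rust sketch below (`is_valid_split` and `find_split_linear` are our own names) checks a candidate p1 and then simply tries every candidate in order. That scan costs O(n) checks; the binary search keeps exactly the same test but discards half of the remaining candidates at each step.

```rust
/// The two comparisons for a candidate p1 (with p2 = p - p1), treating any
/// comparison that would read outside an array as automatically satisfied.
/// Assumes arr1 is the shorter array, the arrays are not both empty,
/// and p1 <= min(n, p) so that p2 does not underflow.
pub fn is_valid_split(arr1: &[i32], arr2: &[i32], p1: usize) -> bool {
    let p = (arr1.len() + arr2.len() - 1) / 2;
    let p2 = p - p1;
    // Comparison 1: arr1[p1 - 1] <= arr2[p2]
    let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
    // Comparison 2: arr2[p2 - 1] <= arr1[p1]
    let cond2 = p1 == arr1.len() || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
    cond1 && cond2
}

/// Tries every candidate p1 in order: O(n) checks instead of O(log n),
/// but it uses exactly the test the binary search will use.
pub fn find_split_linear(arr1: &[i32], arr2: &[i32]) -> usize {
    let p = (arr1.len() + arr2.len() - 1) / 2;
    (0..=arr1.len().min(p))
        .find(|&p1| is_valid_split(arr1, arr2, p1))
        .expect("sorted inputs always have a valid split")
}
```

With arr1 = [1, 3] and arr2 = [2, 4], p1 = 0 fails the second comparison (2 > 1), while p1 = 1 passes both, so the scan returns 1.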

Rust Implementation

Our algorithm description was quite long, but the advantage of having so many details is that the code starts to write itself! We’ll start with our Rust implementation. Stage 1 is to define all of the terms using our input values. We want to define our sizes and array references so that arr1 always refers to the shorter array:

pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
    let mut n = nums1.len();
    let mut m = nums2.len();
    let mut arr1: &Vec<i32> = &nums1;
    let mut arr2: &Vec<i32> = &nums2;
    if m < n {
        n = nums2.len();
        m = nums1.len();
        arr1 = &nums2;
        arr2 = &nums1;
    }
    let t = n + m;
    let p: usize = (t - 1) / 2;

    ...
}

Anatomy of a Binary Search

The next stage is the binary search, so we can find p1 and p2. Now a binary search is a particular kind of loop pattern. Like many of the loop patterns we worked with in the previous weeks, we can express it recursively, or with a loop construct like for or while. We’ll start with a while loop solution for Rust, and then show the recursive solution with Haskell.

All loops maintain some kind of state. For a binary search, the primary state is the two endpoints representing our “interval of interest”. This starts out as the entire interval, and shrinks by half each time until we’ve narrowed it to a single element (or no elements). We’ll represent these interval endpoints with low and hi. Our loop concludes once low is as large as hi.

let mut low = 0;
// Use the shorter array size!
let mut hi = n;
while low < hi {
    ...
}
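For comparison, here is the textbook form of this loop pattern as a Rust sketch: finding a target value in a sorted slice. It is not part of the median solution. The function name is ours, and real code would call the standard library’s `slice::binary_search`, which returns a `Result` carrying the insertion point when the target is absent. The sketch simply shows the interval of interest, here the half-open range [low, hi), halving on every iteration until it is empty.

```rust
use std::cmp::Ordering;

/// Textbook binary search for `target` in a sorted slice.
/// [low, hi) is the half-open "interval of interest"; it halves each
/// iteration until it is empty.
pub fn binary_search(arr: &[i32], target: i32) -> Option<usize> {
    let mut low = 0;
    let mut hi = arr.len();
    while low < hi {
        // Same as (low + hi) / 2, but cannot overflow.
        let mid = low + (hi - low) / 2;
        match arr[mid].cmp(&target) {
            Ordering::Equal => return Some(mid),
            // The target can only be to the right of mid.
            Ordering::Less => low = mid + 1,
            // The target can only be to the left of mid.
            Ordering::Greater => hi = mid,
        }
    }
    None
}
```

Searching `[1, 3, 5, 7, 9]` for 7 probes indices 2, 4, and then 3, returning `Some(3)`.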

In our particular case, we are also trying to determine the values for p1 and p2. Each time we specify an interval, we’ll see if the midpoint of that interval (between low and hi) is the correct value of p1:

...

let mut low = 0;
let mut hi = n;
let mut p1 = 0;
let mut p2 = 0;
while low < hi {
    p1 = (low + hi) / 2;
    p2 = p - p1;
    ...
}

Now we evaluate this p1 value using the two conditions we specified in our algorithm. These are mostly self-explanatory, but we do need to handle edge cases where an index would fall outside the array bounds.

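For example, if p1 is 0, the first condition is automatically true: if it were false, we would want fewer elements from arr1, which is impossible when p1 is already 0. Similarly, the second condition is automatically true when p1 is n (we can’t take any more elements from arr1) or when p2 is 0 (there is no arr2[p2 - 1] to compare).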

...

let mut low = 0;
let mut hi = n;
let mut p1 = 0;
let mut p2 = 0;
while low < hi {
    p1 = (low + hi) / 2;
    p2 = p - p1;
    let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
    let cond2 = p1 == n || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
    if cond1 && cond2 {
        break;
    } else if !cond1 {
        p1 -= 1;
        hi = p1;
    } else {
        p1 += 1;
        low = p1;
    }
}
p2 = p - p1;

...

If both conditions are met, you’ll see we break, because we’ve found the right value for p1! Otherwise, we know p1 is invalid. This means we want to exclude the existing p1 value from further consideration by changing either low or hi to remove it from the interval of interest.

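So if cond1 is false, hi becomes p1 - 1, and if cond2 is false, low becomes p1 + 1. In both cases, we also modify p1 itself first, so that if the loop ends right there (because low has caught up with hi), p1 holds that last remaining candidate rather than the value we just ruled out. This works because we treat [low, hi] as a closed range of candidates that always contains the correct p1, so once low equals hi, that single remaining value must be the answer.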
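A common alternative to the explicit edge-case checks in cond1 and cond2 is to pretend each array has a sentinel of negative infinity before its first element and positive infinity after its last. The Rust sketch below assumes that approach; the helper names `left`, `right`, and `conditions` are hypothetical, and values are widened to `i64` so the sentinels can never coincide with a real `i32` input.

```rust
/// Reads arr[i - 1] (the value just left of the cut), or i64::MIN when i == 0.
fn left(arr: &[i32], i: usize) -> i64 {
    if i == 0 {
        i64::MIN
    } else {
        i64::from(arr[i - 1])
    }
}

/// Reads arr[i] (the value just right of the cut), or i64::MAX when i == arr.len().
fn right(arr: &[i32], i: usize) -> i64 {
    arr.get(i).map_or(i64::MAX, |&x| i64::from(x))
}

/// Both conditions for a candidate split (p1, p2); the sentinels make
/// every edge case compare as "satisfied" without special checks.
pub fn conditions(arr1: &[i32], arr2: &[i32], p1: usize, p2: usize) -> (bool, bool) {
    let cond1 = left(arr1, p1) <= right(arr2, p2);
    let cond2 = left(arr2, p2) <= right(arr1, p1);
    (cond1, cond2)
}
```

The trade-off is a little extra widening work in exchange for conditions that read exactly like the two comparisons in the algorithm description.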

Getting the Final Answer

Now that we have p1 and p2, we have to do a couple of final tricks to get the answer. We want the smaller of arr1[p1] and arr2[p2]. We have to handle the edge case where p1 might be n, and we also want to increment the index of whichever array we take from. Note that p2 cannot be out of bounds at this point: since arr1 is the shorter array, p2 <= p < m.

let mut median = arr2[p2];
if p1 < n && arr1[p1] < arr2[p2] {
    median = arr1[p1];
    p1 += 1;
} else {
    p2 += 1;
}

If the total number of elements is odd, we can simply return this number (converting it to a float). However, in the even case we need one more number to take an average. So we’ll compare the values at the indices again, but now accounting for the fact that either index (but not both) could be out of bounds.

let mut median = arr2[p2];
if p1 < n && arr1[p1] < arr2[p2] {
    median = arr1[p1];
    p1 += 1;
} else {
    p2 += 1;
}

if t % 2 == 0 {
    // Even total: add the next-smallest value, skipping an exhausted array
    if p1 >= n {
        median += arr2[p2];
    } else if p2 >= m {
        median += arr1[p1];
    } else {
        median += arr1[p1].min(arr2[p2]);
    }
    let median_f: f64 = median.into();
    return median_f / 2.0;
} else {
    return median.into();
}
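One design note on the even case above: it adds two `i32` values before converting to `f64`. That is safe under LeetCode’s stated constraints for this problem (values within ±10^6), but the sum could overflow for values near `i32::MAX`. If you want to support the full `i32` range, a hedged alternative is to convert before adding, as in this small Rust sketch (the `average` helper is our own):

```rust
/// Averages two i32 values as an f64. Converting each operand first means
/// the addition happens in f64 and cannot overflow the way `a + b` on i32 can.
pub fn average(a: i32, b: i32) -> f64 {
    (f64::from(a) + f64::from(b)) / 2.0
}
```

Every `i32` converts to `f64` exactly, so for inputs in that range the result matches the integer-sum version whenever the latter does not overflow.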

Here’s the complete implementation:

pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
    let mut n = nums1.len();
    let mut m = nums2.len();
    let mut arr1: &Vec<i32> = &nums1;
    let mut arr2: &Vec<i32> = &nums2;
    // Ensure arr1 is the shorter array
    if m < n {
        n = nums2.len();
        m = nums1.len();
        arr1 = &nums2;
        arr2 = &nums1;
    }
    let t = n + m;
    // Number of elements before the median (assumes t >= 1)
    let p: usize = (t - 1) / 2;

    // Binary search for p1 over the closed candidate range [low, hi]
    let mut low = 0;
    let mut hi = n;
    let mut p1 = 0;
    let mut p2 = 0;
    while low < hi {
        p1 = (low + hi) / 2;
        p2 = p - p1;
        let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
        let cond2 = p1 == n || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
        if cond1 && cond2 {
            break;
        } else if !cond1 {
            // Too many elements from arr1: p1 must be smaller
            p1 -= 1;
            hi = p1;
        } else {
            // Too few elements from arr1: p1 must be larger
            p1 += 1;
            low = p1;
        }
    }
    p2 = p - p1;

    // Take the smaller of arr1[p1] and arr2[p2] (p2 is always in bounds here)
    let mut median = arr2[p2];
    if p1 < n && arr1[p1] < arr2[p2] {
        median = arr1[p1];
        p1 += 1;
    } else {
        p2 += 1;
    }

    if t % 2 == 0 {
        // Even total: add the next-smallest value, skipping an exhausted array
        if p1 >= n {
            median += arr2[p2];
        } else if p2 >= m {
            median += arr1[p1];
        } else {
            median += arr1[p1].min(arr2[p2]);
        }
        let median_f: f64 = median.into();
        median_f / 2.0
    } else {
        median.into()
    }
}
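To build some confidence in all this index juggling, a brute-force cross-check helps. The Rust sketch below repeats the implementation (condensed slightly; for instance, the swap is written as a single `if` expression) so that it compiles on its own, then compares it against a sort-based reference on every pair of small sorted arrays. The helper names and the size limits are arbitrary choices for illustration.

```rust
// The complete implementation again (condensed slightly, same logic),
// repeated so that this block compiles on its own.
pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
    // Make arr1 the shorter array.
    let (arr1, arr2) = if nums1.len() <= nums2.len() { (&nums1, &nums2) } else { (&nums2, &nums1) };
    let (n, m) = (arr1.len(), arr2.len());
    let t = n + m;
    let p = (t - 1) / 2;
    let (mut low, mut hi, mut p1) = (0, n, 0);
    while low < hi {
        p1 = (low + hi) / 2;
        let p2 = p - p1;
        let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
        let cond2 = p1 == n || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
        if cond1 && cond2 {
            break;
        } else if !cond1 {
            p1 -= 1;
            hi = p1;
        } else {
            p1 += 1;
            low = p1;
        }
    }
    let mut p2 = p - p1;
    let mut median = arr2[p2];
    if p1 < n && arr1[p1] < arr2[p2] {
        median = arr1[p1];
        p1 += 1;
    } else {
        p2 += 1;
    }
    if t % 2 == 0 {
        if p1 >= n {
            median += arr2[p2];
        } else if p2 >= m {
            median += arr1[p1];
        } else {
            median += arr1[p1].min(arr2[p2]);
        }
        f64::from(median) / 2.0
    } else {
        f64::from(median)
    }
}

/// Reference answer: concatenate, sort, and read off the middle element(s).
fn median_by_sorting(a: &[i32], b: &[i32]) -> f64 {
    let mut all = [a, b].concat();
    all.sort_unstable();
    let t = all.len();
    if t % 2 == 1 {
        f64::from(all[t / 2])
    } else {
        (f64::from(all[t / 2 - 1]) + f64::from(all[t / 2])) / 2.0
    }
}

/// Every non-decreasing array of length `len` with values in 0..max.
fn sorted_arrays(len: usize, max: i32) -> Vec<Vec<i32>> {
    if len == 0 {
        return vec![Vec::new()];
    }
    let mut out = Vec::new();
    for prefix in sorted_arrays(len - 1, max) {
        // Start from the last value so each new array stays sorted.
        let start = prefix.last().copied().unwrap_or(0);
        for v in start..max {
            let mut next = prefix.clone();
            next.push(v);
            out.push(next);
        }
    }
    out
}

/// Checks the binary-search solution against the sorting reference on every
/// pair of small sorted arrays (lengths 0..=4, values 0..3, not both empty),
/// returning how many pairs were compared.
pub fn exhaustive_check() -> usize {
    let mut checked = 0;
    for len1 in 0..=4 {
        for len2 in 0..=4 {
            if len1 + len2 == 0 {
                continue;
            }
            for a in sorted_arrays(len1, 3) {
                for b in sorted_arrays(len2, 3) {
                    let expected = median_by_sorting(&a, &b);
                    let actual = find_median_sorted_arrays(a.clone(), b.clone());
                    assert_eq!(actual, expected, "a = {:?}, b = {:?}", a, b);
                    checked += 1;
                }
            }
        }
    }
    checked
}
```

With lengths up to 4 and values 0 to 2, there are 35 possible arrays per side, so the check covers 35 * 35 - 1 = 1224 pairs (excluding the case where both arrays are empty).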

Haskell Implementation

Now let’s examine the Haskell implementation, which uses Data.Vector (imported qualified as V) for constant-time indexing. Unlike the LeetCode version, we’ll just assume our inputs are Double already instead of doing a conversion. Once again, we start by defining the terms:

medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
  where
    n' = V.length input1
    m' = V.length input2
    t = n' + m'
    p = (t - 1) `quot` 2
    (n, m, arr1, arr2) = if V.length input1 <= V.length input2
      then (n', m', input1, input2) else (m', n', input2, input1)

    ...

Now we’ll implement the binary search, this time as a recursive function. We’ll do this in two parts, starting with a helper function that tells us whether a particular index is correct for p1. The trick, though, is that we’ll return an Ordering instead of just a Bool:

-- data Ordering = LT | EQ | GT
f :: Int -> Ordering

This lets us signal 3 possibilities. If we return EQ, this means the index is valid. If we return LT, this will mean we want fewer values from arr1. And then GT means we want more values from arr1.

With this framing, the helper’s implementation is easy to see. We determine the appropriate p2, evaluate our two conditions, and return the corresponding Ordering:

medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
  where
    ...
    f :: Int -> Ordering
    f pi1 =
      let pi2 = p - pi1
          cond1 = pi1 == 0 || arr1 V.! (pi1 - 1) <= arr2 V.! pi2
          cond2 = pi1 == n || pi2 == 0 || (arr2 V.! (pi2 - 1) <= arr1 V.! pi1)
      in  if cond1 && cond2 then EQ else if (not cond1) then LT else GT

Now we can use this helper in a recursive binary search. The search tracks the two endpoints of our interval as an (Int, Int) pair, and returns the correct value for p1. The implementation applies the base case (return low if low >= hi), determines the midpoint, calls our helper, and then recurses on the appropriate half based on the helper’s result.

medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
  where
    ...
    f :: Int -> Ordering
    f pi1 = ...
    
    search :: (Int, Int) -> Int
    search (low, hi) = if low >= hi then low else
      let mid = (low + hi) `quot` 2
      in  case f mid of
            EQ -> mid
            LT -> search (low, mid - 1)
            GT -> search (mid + 1, hi)

    p1 = search (0, n)
    p2 = p - p1

    ...
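For readers following along in Rust, the same Ordering-based structure translates directly. In the sketch below, `classify`, `search`, and `find_p1` are our own names, and `std::cmp::Ordering` plays the role of Haskell’s Ordering. It makes the same closed-interval assumption as the Haskell version: when low meets hi, low is returned as the answer. Recursion depth is only logarithmic, so stack usage is not a concern.

```rust
use std::cmp::Ordering;

/// Rust analogue of the Haskell helper `f`: Equal if p1 is correct,
/// Less if we need fewer elements from arr1, Greater if we need more.
/// Assumes arr1 is the shorter array and p1 <= p (true for every midpoint probed).
fn classify(arr1: &[i32], arr2: &[i32], p: usize, p1: usize) -> Ordering {
    let p2 = p - p1;
    let cond1 = p1 == 0 || arr1[p1 - 1] <= arr2[p2];
    let cond2 = p1 == arr1.len() || p2 == 0 || arr2[p2 - 1] <= arr1[p1];
    if cond1 && cond2 {
        Ordering::Equal
    } else if !cond1 {
        Ordering::Less
    } else {
        Ordering::Greater
    }
}

/// Recursive search over the closed candidate range [low, hi],
/// mirroring the Haskell `search`.
fn search(arr1: &[i32], arr2: &[i32], p: usize, low: usize, hi: usize) -> usize {
    if low >= hi {
        return low;
    }
    let mid = (low + hi) / 2;
    match classify(arr1, arr2, p, mid) {
        Ordering::Equal => mid,
        Ordering::Less => search(arr1, arr2, p, low, mid - 1),
        Ordering::Greater => search(arr1, arr2, p, mid + 1, hi),
    }
}

/// Finds p1 for two sorted arrays (arr1 must be the shorter, not both empty).
pub fn find_p1(arr1: &[i32], arr2: &[i32]) -> usize {
    let p = (arr1.len() + arr2.len() - 1) / 2;
    search(arr1, arr2, p, 0, arr1.len())
}
```

For arr1 = [1, 2, 3, 10] and arr2 = [4, 5, 6, 7, 8], the search probes 2 (Greater) and then 3 (Equal), so `find_p1` returns 3.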

For the final part of the problem, we’ll define one more helper. Given p1 and p2, it returns the smaller of the values at those two indices (accounting for edge cases), along with the two updated indices (since one of them increments).

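This is a matter of lazily defining the “next” value for each array, the “end” condition of each array, and the “result” if that array’s value is chosen. Because of laziness, a lookup past the end of an array (like arr1 V.! pi1 when pi1 == n) is never actually evaluated: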

medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = ...
  where
    ...

    findNext pi1 pi2 =
      let next1 = arr1 V.! pi1
          next2 = arr2 V.! pi2
          end1 = pi1 >= n
          end2 = pi2 >= m
          res1 = (next1, pi1 + 1, pi2)
          res2 = (next2, pi1, pi2 + 1)
      in  if end1 then res2
            else if end2 then res1
            else if next1 <= next2 then res1 else res2

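Now we just apply this either once or twice to get our result! Laziness helps here too: median2 (and the second call to findNext) is only evaluated when t is even.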

medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = result
  where
    ...

    tIsEven = even t
    (median1, nextP1, nextP2) = findNext p1 p2
    (median2, _, _) = findNext nextP1 nextP2
    result = if tIsEven
      then (median1 + median2) / 2.0
      else median1

Here’s the complete implementation:

import qualified Data.Vector as V

medianSortedArrays :: V.Vector Double -> V.Vector Double -> Double
medianSortedArrays input1 input2 = result
  where
    n' = V.length input1
    m' = V.length input2
    t = n' + m'
    p = (t - 1) `quot` 2
    (n, m, arr1, arr2) = if V.length input1 <= V.length input2
      then (n', m', input1, input2) else (m', n', input2, input1)

    -- Evaluate a candidate index pi1 for p1 in arr1
    -- If pi1 is the correct value for p1, return EQ
    -- If we need fewer elements from arr1 (move left), return LT
    -- If we need more elements from arr1 (move right), return GT
    -- Precondition: pi1 <= n
    f :: Int -> Ordering
    f pi1 =
      let pi2 = p - pi1
          cond1 = pi1 == 0 || arr1 V.! (pi1 - 1) <= arr2 V.! pi2
          cond2 = pi1 == n || pi2 == 0 || (arr2 V.! (pi2 - 1) <= arr1 V.! pi1)
      in  if cond1 && cond2 then EQ else if (not cond1) then LT else GT
    
    search :: (Int, Int) -> Int
    search (low, hi) = if low >= hi then low else
      let mid = (low + hi) `quot` 2
      in  case f mid of
            EQ -> mid
            LT -> search (low, mid - 1)
            GT -> search (mid + 1, hi)
    
    findNext pi1 pi2 =
      let next1 = arr1 V.! pi1
          next2 = arr2 V.! pi2
          end1 = pi1 >= n
          end2 = pi2 >= m
          res1 = (next1, pi1 + 1, pi2)
          res2 = (next2, pi1, pi2 + 1)
      in  if end1 then res2
            else if end2 then res1
            else if next1 <= next2 then res1 else res2

    p1 = search (0, n)
    p2 = p - p1

    tIsEven = even t
    (median1, nextP1, nextP2) = findNext p1 p2
    (median2, _, _) = findNext nextP1 nextP2
    result = if tIsEven
      then (median1 + median2) / 2.0
      else median1

Conclusion

If you want to learn more about these kinds of problem solving techniques, you should take our course Solve.hs! In the coming weeks, we’ll see more problems related to data structures and algorithms, which are covered extensively in Modules 2 and 3 of that course!
