Sets & Heaps in Haskell and Rust
In the last few articles, we’ve spent some time implementing algorithms on binary trees by hand. Most of the time, though, you won’t have to work with binary trees directly, since they are built into other data structures.
Today, we’ll solve a problem that relies on both a “set” and a “heap”. To learn more about using these structures in Haskell, sign up for our Solve.hs course, where you’ll spend an entire module on data structures!
The Problem
Today’s problem is Find k pairs with smallest sums. The input consists of two integer arrays, each sorted in non-decreasing order, plus a number k. Our job is to return the k pairs with the smallest sums, where a “pair” consists of one number from the first array and one number from the second array.
Here’s an example input and output:
inputNums1 = [1,10,101,102]
inputNums2 = [7,8,9,10,11]
k = 11
output =
  [ (1,7), (1,8), (1,9), (1,10), (1,11)
  , (10,7), (10,8), (10,9), (10,10)
  , (10,11), (101,7)
  ]
Observe that we return the numbers themselves in pairs, not their sums and not their indices.
While the (implicit) index pair behind each result must be unique, the numbers themselves need not be. For example, if we have the lists [1,1] and [5,7], and k is 2, we should return [(1,5), (1,5)]: each of the two 1’s from the first list is paired with the 5 from the second list.
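To pin down the expected output, including the duplicate-value case, here is a brute-force reference written in Rust. It is not the approach this article develops. It builds every pair, stable-sorts the pairs by sum, and keeps the first k. The function name `k_smallest_pairs_brute` is our own. Because it costs O(n1·n2·log(n1·n2)), it is only useful for checking small inputs.

```rust
/// Brute-force reference (hypothetical helper): enumerate every
/// (nums1[i], nums2[j]) pair, sort by sum, and keep the first k.
fn k_smallest_pairs_brute(nums1: &[i32], nums2: &[i32], k: usize) -> Vec<(i32, i32)> {
    let mut pairs: Vec<(i32, i32)> = Vec::with_capacity(nums1.len() * nums2.len());
    for &a in nums1 {
        for &b in nums2 {
            pairs.push((a, b));
        }
    }
    // `sort_by_key` is stable, so pairs with equal sums keep their
    // generation order. Summing in i64 rules out overflow.
    pairs.sort_by_key(|&(a, b)| a as i64 + b as i64);
    // If k exceeds the number of pairs, every pair is returned.
    pairs.truncate(k);
    pairs
}
```

For example, `k_smallest_pairs_brute(&[1, 1], &[5, 7], 2)` returns `[(1, 5), (1, 5)]`, which matches the duplicate case above.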
The Algorithm
At the center of our algorithm is a min-heap. Each element holds a pair of numbers, but the elements are ordered by the pair’s sum. Each element also records the index of each number within its own array.
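Tuples are compared field by field, so putting the sum first makes it the primary ordering key, and the index pair only breaks ties. Here is a minimal Rust sketch of that element layout. It uses the standard library's `BinaryHeap` and `std::cmp::Reverse`, the same tools the Rust solution uses. The `Elem` alias and the `pop_order` helper are our own names.

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// A heap element laid out as (sum, (i1, i2), (v1, v2)).
type Elem = (i32, (usize, usize), (i32, i32));

/// Pushes every element into a min-heap and returns them in pop order.
fn pop_order(elems: &[Elem]) -> Vec<Elem> {
    // The sum decides first; equal sums fall back to the index pair.
    // BinaryHeap is a max-heap, and Reverse flips the comparison so the
    // smallest tuple is popped first.
    let mut heap: BinaryHeap<Reverse<Elem>> = elems.iter().copied().map(Reverse).collect();
    let mut popped = Vec::new();
    while let Some(Reverse(elem)) = heap.pop() {
        popped.push(elem);
    }
    popped
}
```

If you push `(17, (1,0), (10,7))`, `(8, (0,0), (1,7))`, and `(9, (0,1), (1,8))`, they pop in sum order: 8, 9, 17. If two elements share the sum 109, the one with index pair (2,1) comes out before (3,0).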
We start by pairing the first element of each array and inserting that pair into our heap; because both arrays are sorted, this pair must have the smallest sum. Then we loop: we extract the minimum from the heap and try adding its “neighbors” to the heap. A “neighbor” of the index pair (i1, i2) comes from incrementing one of the two indices, so it is either (i1 + 1, i2) or (i1, i2 + 1). Since the arrays are sorted, a neighbor’s sum is never smaller than the sum it came from. We continue until we have extracted k elements (or our heap is empty).
Each time we insert a pair into the heap, we also insert its pair of indices into a set. This lets us avoid counting any pair twice, which matters because most index pairs are neighbors of two others: for example, (1,1) is a neighbor of both (0,1) and (1,0).
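The neighbor rule and the visited set fit in one small helper. In the Rust sketch below, `new_neighbors` is a hypothetical name of our own. It keeps a neighbor only if the neighbor lies inside both arrays and `HashSet::insert` reports it as new, so each coordinate is marked visited at the moment it is accepted.

```rust
use std::collections::HashSet;

/// Returns the in-bounds neighbors of `(i1, i2)` that were not seen
/// before, marking each accepted one as visited.
/// `n1` and `n2` are the lengths of the two arrays.
fn new_neighbors(
    (i1, i2): (usize, usize),
    n1: usize,
    n2: usize,
    visited: &mut HashSet<(usize, usize)>,
) -> Vec<(usize, usize)> {
    let candidates = [(i1 + 1, i2), (i1, i2 + 1)];
    candidates
        .iter()
        .copied()
        // Stay inside both arrays.
        .filter(|&(a, b)| a < n1 && b < n2)
        // `insert` returns false if the coordinate was already present.
        .filter(|&c| visited.insert(c))
        .collect()
}
```

Suppose the visited set starts as {(0,0)} and the array lengths are 4 and 5. Expanding (0,0) yields (1,0) and (0,1). Expanding (0,1) yields (1,1) and (0,2). Expanding (1,0) afterwards yields only (2,0), because (1,1) was already accepted.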
Using the first example from the previous section, here are the first few steps of our process. Each heap element is a 3-tuple containing the sum, the indices (a 2-tuple), and the values (another 2-tuple). Each step shows the state just before the next extraction.
Step 1:
Heap: [(8, (0,0), (1,7))]
Visited: [(0,0)]
Output: []
Step 2:
Heap: [(9, (0,1), (1,8)), (17, (1,0), (10,7))]
Visited: [(0,0), (0,1), (1,0)]
Output: [(1,7)]
Step 3:
Heap: [(10, (0,2), (1,9)), (17, (1,0), (10,7)), (18, (1,1), (10,8))]
Visited: [(0,0), (0,1), (0,2), (1,0), (1,1)]
Output: [(1,7), (1,8)]
We continue this process until we have gathered k outputs (or the heap runs out).
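To check a trace like this one without working it out by hand, it helps to record the state after each extraction. The Rust sketch below runs the same algorithm and records snapshots. Each snapshot holds the heap contents in sorted order, the visited set in sorted order, and the output so far. The `Snapshot` type and the `snapshot` and `trace` functions are our own hypothetical names. The sketch assumes both arrays are non-empty.

```rust
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};

/// A heap element: (sum, (i1, i2), (v1, v2)).
type Elem = (i32, (usize, usize), (i32, i32));

/// The state shown in one step of the trace.
#[derive(Debug, PartialEq)]
struct Snapshot {
    heap: Vec<Elem>,              // heap contents, sorted ascending
    visited: Vec<(usize, usize)>, // visited index pairs, sorted
    output: Vec<(i32, i32)>,      // pairs extracted so far
}

/// Copies the current state into a Snapshot.
fn snapshot(
    heap: &BinaryHeap<Reverse<Elem>>,
    visited: &HashSet<(usize, usize)>,
    output: &[(i32, i32)],
) -> Snapshot {
    // BinaryHeap::iter and HashSet::iter use arbitrary order, so sort.
    let mut heap_items: Vec<Elem> = heap.iter().map(|r| r.0).collect();
    heap_items.sort();
    let mut seen: Vec<(usize, usize)> = visited.iter().copied().collect();
    seen.sort();
    Snapshot { heap: heap_items, visited: seen, output: output.to_vec() }
}

/// Records the initial state, then one snapshot per extraction, stopping
/// once there are `steps` snapshots or the heap is empty.
fn trace(nums1: &[i32], nums2: &[i32], steps: usize) -> Vec<Snapshot> {
    let (n1, n2) = (nums1.len(), nums2.len());
    let mut heap: BinaryHeap<Reverse<Elem>> = BinaryHeap::new();
    heap.push(Reverse((nums1[0] + nums2[0], (0, 0), (nums1[0], nums2[0]))));
    let mut visited: HashSet<(usize, usize)> = HashSet::new();
    visited.insert((0, 0));
    let mut output: Vec<(i32, i32)> = Vec::new();
    let mut rows = vec![snapshot(&heap, &visited, &output)];

    while rows.len() < steps {
        let (_, (i1, i2), (v1, v2)) = match heap.pop() {
            Some(Reverse(elem)) => elem,
            None => break,
        };
        // `insert` returns true only for a coordinate not seen before.
        if i1 + 1 < n1 && visited.insert((i1 + 1, i2)) {
            let a = nums1[i1 + 1];
            heap.push(Reverse((a + v2, (i1 + 1, i2), (a, v2))));
        }
        if i2 + 1 < n2 && visited.insert((i1, i2 + 1)) {
            let b = nums2[i2 + 1];
            heap.push(Reverse((v1 + b, (i1, i2 + 1), (v1, b))));
        }
        output.push((v1, v2));
        rows.push(snapshot(&heap, &visited, &output));
    }
    rows
}
```

Calling `trace(&[1, 10, 101, 102], &[7, 8, 9, 10, 11], 3)` returns three snapshots, one for each of Steps 1 through 3. For instance, the third snapshot's heap is `[(10, (0, 2), (1, 9)), (17, (1, 0), (10, 7)), (18, (1, 1), (10, 8))]`.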
Haskell Solution
The core structure of this problem isn’t hard. We define some initial terms, such as our data structures and outputs, and then we have a single loop. We’ll express this loop as a recursive function. It takes the number of pairs still to extract, the visited set, the heap, and the accumulated output, and ultimately returns the output. Throughout this problem we’ll use the type alias I2 to refer to a tuple of two integers.
import qualified Data.Heap as H
import qualified Data.Set as S
import qualified Data.Vector as V
type I2 = (Int, Int)
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
  where
    n1 = V.length arr1
    n2 = V.length arr2
    f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
    f remaining visited heap acc = ...
Let’s fill in our loop, starting with the edge cases. If the remaining count reaches 0, or if our heap is empty, we return the accumulated results in reverse, since we build that list by prepending.
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
  where
    f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
    f 0 _ _ acc = reverse acc
    f remaining visited heap acc = case H.view heap of
      Nothing -> reverse acc
      Just ((_, (i1, i2), (v1, v2)), restHeap) -> ...
Our main case comes from extracting the minimum element of the heap. We don’t actually need its sum; we only put the sum in the first position of the tuple so that it is the primary sort key for the heap. Let’s define the next possible coordinate pairs from incrementing i1 and i2, as well as the new sums we get from using those indices:
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
  where
    f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
    f 0 _ _ acc = reverse acc
    f remaining visited heap acc = case H.view heap of
      Nothing -> reverse acc
      Just ((_, (i1, i2), (v1, v2)), restHeap) ->
        let c1 = (i1 + 1, i2)
            c2 = (i1, i2 + 1)
            inc1 = arr1 V.! (i1 + 1) + v2
            inc2 = v1 + arr2 V.! (i2 + 1)
        in  ...
Now we try adding these values to the rest of the heap. If the index is out of range, or if we’ve already visited the coordinate, we skip the new value and keep the old heap. Otherwise, we insert it into the heap. For the second value, we use heap1 (the result of trying to add the first value) as the baseline. Because Haskell is lazy, inc1 and inc2 are only evaluated when their bounds check passes, so the out-of-range lookups never actually happen.
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
  where
    f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
    f 0 _ _ acc = reverse acc
    f remaining visited heap acc = case H.view heap of
      Nothing -> reverse acc
      Just ((_, (i1, i2), (v1, v2)), restHeap) ->
        let c1 = (i1 + 1, i2)
            c2 = (i1, i2 + 1)
            inc1 = arr1 V.! (i1 + 1) + v2
            inc2 = v1 + arr2 V.! (i2 + 1)
            heap1 = if i1 + 1 < n1 && S.notMember c1 visited
                      then H.insert (inc1, c1, (arr1 V.! (i1 + 1), v2)) restHeap else restHeap
            heap2 = if i2 + 1 < n2 && S.notMember c2 visited
                      then H.insert (inc2, c2, (v1, arr2 V.! (i2 + 1))) heap1 else heap1
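        in  ...
Now we complete our recursive loop function by adding the new coordinates to the visited set and making a recursive call. We decrement the remaining count and prepend the values to our accumulated list. Adding an out-of-range coordinate to the set is harmless, because the bounds check rejects it anyway.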
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = ...
  where
    f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
    f 0 _ _ acc = reverse acc
    f remaining visited heap acc = case H.view heap of
      Nothing -> reverse acc
      Just ((_, (i1, i2), (v1, v2)), restHeap) ->
        let c1 = (i1 + 1, i2)
            c2 = (i1, i2 + 1)
            inc1 = arr1 V.! (i1 + 1) + v2
            inc2 = v1 + arr2 V.! (i2 + 1)
            heap1 = if i1 + 1 < n1 && S.notMember c1 visited
                      then H.insert (inc1, c1, (arr1 V.! (i1 + 1), v2)) restHeap else restHeap
            heap2 = if i2 + 1 < n2 && S.notMember c2 visited
                      then H.insert (inc2, c2, (v1, arr2 V.! (i2 + 1))) heap1 else heap1
            visited' = foldr S.insert visited ([c1, c2] :: [I2])
        in  f (remaining - 1) visited' heap2 ((v1,v2) : acc)
To complete the function, we define our initial heap and make the first call to our loop function:
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = f k (S.singleton (0,0)) initialHeap []
  where
    val1 = arr1 V.! 0
    val2 = arr2 V.! 0
    initialHeap = H.singleton (val1 + val2, (0,0), (val1, val2))
    f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
    f = ...
Now we’re done! Here’s our complete solution:
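import qualified Data.Heap as H   -- from the "heap" package (MinHeap, view)
import qualified Data.Set as S
import qualified Data.Vector as V

type I2 = (Int, Int)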
findKPairs :: V.Vector Int -> V.Vector Int -> Int -> [I2]
findKPairs arr1 arr2 k = f k (S.singleton (0,0)) initialHeap []
  where
    val1 = arr1 V.! 0
    val2 = arr2 V.! 0
    initialHeap = H.singleton (val1 + val2, (0,0), (val1, val2))
    n1 = V.length arr1
    n2 = V.length arr2
    f :: Int -> S.Set I2 -> H.MinHeap (Int, I2, I2) -> [I2] -> [I2]
    f 0 _ _ acc = reverse acc
    f remaining visited heap acc = case H.view heap of
      Nothing -> reverse acc
      Just ((_, (i1, i2), (v1, v2)), restHeap) ->
        let c1 = (i1 + 1, i2)
            c2 = (i1, i2 + 1)
            -- Laziness: these lookups are only evaluated when the
            -- matching bounds check below succeeds.
            inc1 = arr1 V.! (i1 + 1) + v2
            inc2 = v1 + arr2 V.! (i2 + 1)
            heap1 = if i1 + 1 < n1 && S.notMember c1 visited
                      then H.insert (inc1, c1, (arr1 V.! (i1 + 1), v2)) restHeap else restHeap
            heap2 = if i2 + 1 < n2 && S.notMember c2 visited
                      then H.insert (inc2, c2, (v1, arr2 V.! (i2 + 1))) heap1 else heap1
            visited' = foldr S.insert visited ([c1, c2] :: [I2])
        in  f (remaining - 1) visited' heap2 ((v1,v2) : acc)
Rust Solution
Now, on to our Rust solution. We’ll start by defining our terms. These follow the pattern laid out in our algorithm and the Haskell solution:
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};

pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
    let mut heap: BinaryHeap<Reverse<(i32, (usize,usize), (i32,i32))>> = BinaryHeap::new();
    let val1 = nums1[0];
    let val2 = nums2[0];
    heap.push(Reverse((val1 + val2, (0,0), (val1, val2))));
    let mut visited: HashSet<(usize, usize)> = HashSet::new();
    visited.insert((0,0));
    let mut results = Vec::new();
    let mut remaining = k;
    let n1 = nums1.len();
    let n2 = nums2.len();
    ...
    return results;
}
The most interesting of these is the heap. We parameterize the BinaryHeap type with the same kind of tuple we used in Haskell. Rust’s BinaryHeap is a max-heap, though, so to make it a “min” heap we wrap our values in the Reverse type, which inverts their ordering.
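Wrapping values in `Reverse` is the common approach. An alternative is a small struct whose `Ord` implementation compares in the opposite direction, so that a plain `BinaryHeap` behaves as a min-heap. The Rust sketch below shows this. The `Item` struct and the `sums_in_pop_order` helper are hypothetical names of our own. The trade-off is a few lines of trait boilerplate in exchange for named fields instead of nested tuple positions.

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

/// A heap entry with named fields. Its `Ord` is reversed so that
/// `BinaryHeap` (a max-heap) pops the smallest `sum` first.
#[derive(Debug, PartialEq, Eq)]
struct Item {
    sum: i32,
    idx: (usize, usize),
    vals: (i32, i32),
}

impl Ord for Item {
    fn cmp(&self, other: &Self) -> Ordering {
        // Comparing `other` against `self` flips the order. The remaining
        // fields break ties, which keeps `Ord` consistent with derived `Eq`.
        (other.sum, other.idx, other.vals).cmp(&(self.sum, self.idx, self.vals))
    }
}

impl PartialOrd for Item {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Pops every item and returns the sums in pop order.
fn sums_in_pop_order(items: Vec<Item>) -> Vec<i32> {
    let mut heap: BinaryHeap<Item> = items.into_iter().collect();
    let mut sums = Vec::new();
    while let Some(item) = heap.pop() {
        sums.push(item.sum);
    }
    sums
}
```

With this approach, the loop body could read `item.idx` and `item.vals` instead of destructuring a nested tuple inside `Reverse`.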
Now let’s define the outline of our loop. We keep going until remaining is 0. We will also break if we can’t pop a value from the heap.
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
    ...
    
    while remaining > 0 {
        if let Some(Reverse((_sum, (i1, i2), (v1, v2)))) = heap.pop() {
            ...
        } else {
            break;
        }
        remaining -= 1;
    }
    return results;
}
Now we’ll define our new coordinates and check whether each one can be added to the heap:
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
    ...
    
    while remaining > 0 {
        if let Some(Reverse((_sum, (i1, i2), (v1, v2)))) = heap.pop() {
            let c1 = (i1 + 1, i2);
            let c2 = (i1, i2 + 1);
            if i1 + 1 < n1 && !visited.contains(&c1) {
                let inc1 = nums1[i1 + 1] + v2;
                ...
            }
            if i2 + 1 < n2 && !visited.contains(&c2) {
                let inc2 = v1 + nums2[i2 + 1];
                ...
            }
        } else {
            break;
        }
        remaining -= 1;
    }
    return results;
}
In each case, we mark the new coordinate as visited and push the new element onto the heap. Finally, we push the popped values onto our results vector.
pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
    ...
    
    while remaining > 0 {
        if let Some(Reverse((_sum, (i1, i2), (v1, v2)))) = heap.pop() {
            let c1 = (i1 + 1, i2);
            let c2 = (i1, i2 + 1);
            if i1 + 1 < n1 && !visited.contains(&c1) {
                let inc1 = nums1[i1 + 1] + v2;
                visited.insert(c1);
                heap.push(Reverse((inc1, c1, (nums1[i1 + 1], v2))));
            }
            if i2 + 1 < n2 && !visited.contains(&c2) {
                let inc2 = v1 + nums2[i2 + 1];
                visited.insert(c2);
                heap.push(Reverse((inc2, c2, (v1, nums2[i2 + 1]))));
            }
            results.push(vec![v1,v2]);
        } else {
            break;
        }
        remaining -= 1;
    }
    return results;
}
And that’s all we need! Here’s the complete solution:
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};

pub fn k_smallest_pairs(nums1: Vec<i32>, nums2: Vec<i32>, k: i32) -> Vec<Vec<i32>> {
    let mut heap: BinaryHeap<Reverse<(i32, (usize,usize), (i32,i32))>> = BinaryHeap::new();
    let val1 = nums1[0];
    let val2 = nums2[0];
    heap.push(Reverse((val1 + val2, (0,0), (val1, val2))));
    let mut visited: HashSet<(usize, usize)> = HashSet::new();
    visited.insert((0,0));
    let mut results = Vec::new();
    let mut remaining = k;
    let n1 = nums1.len();
    let n2 = nums2.len();
    
    while remaining > 0 {
        // The sum only orders the heap, so we ignore it here.
        if let Some(Reverse((_sum, (i1, i2), (v1, v2)))) = heap.pop() {
            let c1 = (i1 + 1, i2);
            let c2 = (i1, i2 + 1);
            if i1 + 1 < n1 && !visited.contains(&c1) {
                let inc1 = nums1[i1 + 1] + v2;
                visited.insert(c1);
                heap.push(Reverse((inc1, c1, (nums1[i1 + 1], v2))));
            }
            if i2 + 1 < n2 && !visited.contains(&c2) {
                let inc2 = v1 + nums2[i2 + 1];
                visited.insert(c2);
                heap.push(Reverse((inc2, c2, (v1, nums2[i2 + 1]))));
            }
            results.push(vec![v1,v2]);
        } else {
            break;
        }
        remaining -= 1;
    }
    return results;
}
Conclusion
Next time, we’ll start exploring some graph problems, which also rely on data structures like we used here!
I didn’t go into much detail in this article about using these data structures. If you want an in-depth exploration of how data structures work in Haskell, including the “common API” that helps you use almost all of them, you should sign up for our Solve.hs course! Module 2 is completely dedicated to teaching you about data structures, and you’ll get a lot of practice working with these structures in sample problems.