Learning Rust by Sorting Part 1: Multithreaded Mergesort

Implementing mergesort seemed like a good problem for learning Rust. It's fairly simple, and because it uses divide and conquer it can be extended in a couple of different ways, one of which is using concurrency for better performance. It ended up being a good example of how to do shared-state concurrency in Rust, which has strict rules about ownership and borrowing.

Serial Mergesort

A top-down recursive mergesort can be found in an algorithms textbook such as CLRS. It is typically written as a function merge_sort that calls itself recursively on the left and right halves of an array and then calls merge to combine the two sorted halves. In Rust, merge_sort might look like the following. Each function takes a mutable reference to a slice, nums, and sorts the range nums[left..=right] in place through that reference. Both bounds, left and right, are inclusive.

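/// Sorts nums[left..=right] in place (both bounds are inclusive).
pub fn merge_sort(nums: &mut [i64], left: usize, right: usize) {
    if left < right {
        // Written this way instead of (left + right) / 2 to avoid overflow.
        let middle = left + (right - left) / 2;
        merge_sort(nums, left, middle);
        merge_sort(nums, middle + 1, right);
        merge(nums, left, middle, right);
    }
}

/// Merges the sorted runs nums[left..=middle] and nums[middle + 1..=right].
fn merge(nums: &mut [i64], left: usize, middle: usize, right: usize) {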
    let n1 = middle - left + 1;
    let n2 = right - middle;
    let mut left_nums: Vec<i64> = vec![0; n1];
    let mut right_nums: Vec<i64> = vec![0; n2];

    // copy data into left and right subarrays
    for i in 0..n1 {
        left_nums[i] = nums[left + i];
    }
    for j in 0..n2 {
        right_nums[j] = nums[middle + 1 + j];
    }

    let mut i = 0usize;
    let mut j = 0usize;
    let mut k = left;
    while i < n1 && j < n2 {
        if left_nums[i] <= right_nums[j] {
            nums[k] = left_nums[i];
            i += 1;
        } else {
            nums[k] = right_nums[j];
            j += 1;
        }
        k += 1;
    }

    while i < n1 {
        nums[k] = left_nums[i];
        i += 1;
        k += 1;
    }

    while j < n2 {
        nums[k] = right_nums[j];
        j += 1;
        k += 1;
    }
}
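The copying and index bookkeeping in merge can also be expressed with slice ranges. The sketch below is a hypothetical variant (the name merge_with_copies is illustrative, not part of the original code) intended to have the same contract as merge: it copies each sorted run out with to_vec and then fills nums[left..=right] in a single pass.

```rust
/// Hypothetical variant of `merge` with the same contract, written with slice ranges.
/// Assumes nums[left..=middle] and nums[middle + 1..=right] are each sorted;
/// afterwards nums[left..=right] is sorted.
fn merge_with_copies(nums: &mut [i64], left: usize, middle: usize, right: usize) {
    // Copy each sorted run out of `nums` in a single step.
    let left_nums = nums[left..=middle].to_vec();
    let right_nums = nums[middle + 1..=right].to_vec();
    let (mut i, mut j) = (0, 0);
    // Fill every slot of the output range exactly once.
    for slot in nums[left..=right].iter_mut() {
        // Take from the left run when the right run is exhausted or the left
        // head is <= the right head (`<=` keeps equal elements in order).
        let take_left = j == right_nums.len()
            || (i < left_nums.len() && left_nums[i] <= right_nums[j]);
        if take_left {
            *slot = left_nums[i];
            i += 1;
        } else {
            *slot = right_nums[j];
            j += 1;
        }
    }
}
```

Like merge, it allocates two temporary vectors per call, so both versions do the same amount of work; the difference is only in how the loops are written.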

Multithreaded Mergesort

One simple way to try to speed up merge_sort with multiple threads is to spawn a new thread for the first recursive call and join it after the second call, as below:

pub fn parallel_merge_sort(nums: &mut [i64], left: usize, right: usize) {
    if left < right {
        let middle = left + (right - left) / 2;
        let t = thread::spawn(move || parallel_merge_sort(nums, left, middle));
        parallel_merge_sort(nums, middle + 1, right);
        t.join().unwrap();
        merge(nums, left, middle, right);
    }
}

However, the compiler rejects this code with the two errors shown below. The second one, E0382, comes from the move keyword, which makes the closure passed to thread::spawn capture the variables it uses by value instead of by reference. Here those variables are nums, left, and middle. Capturing a type that implements Copy, such as usize, just copies it, so left and middle are fine. A mutable reference like &mut [i64] does not implement Copy, because a copy would give us two mutable references to the same data, which Rust's borrowing rules forbid. So nums is moved into the closure, and the recursive call on the next line can no longer use it. The first error, E0521, is the deeper problem: thread::spawn requires its closure to be 'static, since the new thread might outlive the function that spawned it, but nums is a borrowed reference that is only valid inside the function body.

error[E0521]: borrowed data escapes outside of function
 --> src/lib.rs:8:17
  |
5 | pub fn parallel_merge_sort(nums: &mut [i64], left: usize, right: usize) {
  |                   ----  - let's call the lifetime of this reference `'1`
  |                   |
  |                   `nums` is a reference that is only valid in the function body
...
8 |         let t = thread::spawn(move || parallel_merge_sort(nums, left, middle));
  |                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  |                 |
  |                 `nums` escapes the function body here
  |                 argument requires that `'1` must outlive `'static`

error[E0382]: borrow of moved value: `nums`
 --> src/lib.rs:9:20
  |
5 | pub fn parallel_merge_sort(nums: &mut [i64], left: usize, right: usize) {
  |                   ---- move occurs because `nums` has type `&mut [i64]`, which does not implement the `Copy` trait
...
8 |         let t = thread::spawn(move || parallel_merge_sort(nums, left, middle));
  |                               -------            ---- variable moved due to use in closure
  |                               |
  |                               value moved into closure here
9 |         parallel_merge_sort(nums, middle + 1, right);
  |                    ^^^^ value borrowed here after move

Some errors have detailed explanations: E0382, E0521.
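The capture rule behind E0382 can be reproduced without any threads. In this minimal sketch (move_capture_demo is a hypothetical function written for illustration), a move closure captures a usize and a &mut [i64] by value:

```rust
/// Hypothetical demo: how a `move` closure captures a Copy value versus a `&mut` slice.
fn move_capture_demo() -> (usize, i64) {
    let mut nums: Vec<i64> = vec![1, 2, 3];
    let left: usize = 1;
    let slice: &mut [i64] = &mut nums;

    // `move` captures both variables by value: `left` (usize implements Copy)
    // is copied into the closure, while `slice` (&mut [i64] does not implement
    // Copy) is moved into it.
    let mut bump = move || {
        slice[left] += 10;
        slice[left]
    };
    let bumped = bump();

    // `left` is still usable here because the closure only received a copy.
    // Using `slice` here instead would be rejected with E0382 (borrow of moved value).
    (left, bumped)
}
```

Calling move_capture_demo() returns (1, 12). thread::spawn adds the 'static requirement on top of this capture behavior, which is where E0521 comes from.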

In order to get around these errors, we need a new approach. One trick is to cut the slice into several smaller, non-overlapping slices so that each thread gets exclusive mutable access to its own piece. That way, no two threads ever hold a mutable reference to the same elements. How do we accomplish this? One answer is a handy slice method in the Rust standard library called chunks_mut. It divides a slice into mutable chunks of a given size; if the length is not evenly divisible, the last chunk is shorter and holds the remainder. Using this, we can spawn a thread that runs merge_sort (the serial implementation from before) on each chunk and then merge the sorted chunks together at the end.
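Here is chunks_mut on its own, without threads. The helper sort_each_chunk is hypothetical and uses the standard library's sort for brevity; it shows the chunk lengths chunks_mut produces and that each chunk is an independent &mut view into the original slice.

```rust
/// Hypothetical helper: sorts each chunk of `data` independently and
/// returns the chunk lengths that `chunks_mut` produced.
fn sort_each_chunk(data: &mut [i64], chunk_size: usize) -> Vec<usize> {
    // Note: chunks_mut panics if chunk_size is 0.
    let mut lens = Vec::new();
    for chunk in data.chunks_mut(chunk_size) {
        lens.push(chunk.len());
        // Each chunk is a &mut [i64] view into `data`, so sorting it changes `data`.
        chunk.sort();
    }
    lens
}
```

With a chunk size of 3, this turns [5, 3, 8, 1, 9, 2, 7] into [3, 5, 8, 1, 2, 9, 7] and reports chunk lengths of 3, 3, and 1. Only the elements within each chunk end up sorted, which is why a merge step is still needed afterwards.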

pub fn parallel_mergesort(data: &mut [i64], threads: usize) {
    let chunks = std::cmp::min(data.len(), threads);
    let mut chunk_lens = Vec::new();
    let data_len = data.len();
    let mut thread_handles = Vec::new();
    for slice in data.chunks_mut(data_len / chunks) {
        let slice_len = slice.len();
        chunk_lens.push(slice_len);
        // Sort each chunk on its own thread.
        let t = thread::spawn(move || merge_sort(slice, 0, slice_len - 1));
        thread_handles.push(t);
    }

    // Wait for every thread to finish.
    for th in thread_handles {
        th.join().unwrap();
    }

    // Merge each sorted chunk into the sorted prefix data[0..=middle].
    let mut end: usize = chunk_lens[0] - 1;
    for &cl in &chunk_lens[1..] {
        let middle = end;
        end += cl;
        merge(data, 0, middle, end);
    }
}

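However, we still run into borrowing errors. Each slice from chunks_mut is a mutable borrow of data, and moving it into a closure passed to thread::spawn, which requires 'static, produces the same E0521 error as before. That 'static requirement also means the mutable borrow of data taken by chunks_mut would have to last forever, so data is still mutably borrowed when we try to borrow it again in merge(data, 0, middle, end) at the bottom of parallel_mergesort (E0499). Joining the threads before the merge doesn't help, because the borrow checker has no way of knowing that join ends the threads' use of the borrowed data. See the full error messages: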

error[E0521]: borrowed data escapes outside of function
  --> src/lib.rs:23:21
   |
14 | pub fn parallel_mergesort(data: &mut [i64], threads: usize) {
   |                           ----  - let's call the lifetime of this reference `'1`
   |                           |
   |                           `data` is a reference that is only valid in the function body
...
23 |             let t = thread::spawn(move || merge_sort(slice, 0, slice_len - 1));
   |                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   |                     |
   |                     `data` escapes the function body here
   |                     argument requires that `'1` must outlive `'static`

error[E0499]: cannot borrow `*data` as mutable more than once at a time
  --> src/lib.rs:33:15
   |
20 |         for slice in data.chunks_mut(data_len / chunks) {
   |                      ---------------------------------- first mutable borrow occurs here
...
23 |             let t = thread::spawn(move || merge_sort(slice, 0, slice_len - 1));
   |                     ---------------------------------------------------------- argument requires that `*data` is borrowed for `'static`
...
33 |         merge(data, 0, middle, end);
   |               ^^^^ second mutable borrow occurs here

Some errors have detailed explanations: E0499, E0521.

We can solve this by using thread::scope (stable since Rust 1.63), which does a couple of things for us. It creates a scope in which every thread spawned through the scope handle is automatically joined before thread::scope returns. Because of that guarantee, the spawned closures no longer need to be 'static; they only need to outlive the scope, so they can borrow non-'static data like the data parameter of parallel_mergesort. By the time we call merge to combine the sorted chunks, thread::scope has returned and the threads' borrows have ended, so borrowing data mutably again is allowed. We also no longer have to join the threads manually through their handles, which makes for cleaner code. See the final version of parallel_mergesort below:

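use std::thread;

pub fn parallel_mergesort(data: &mut [i64], threads: usize) {
    if data.is_empty() {
        return; // nothing to sort; also avoids a division by zero below
    }
    // Use at least one chunk and never more chunks than elements.
    let chunks = std::cmp::min(data.len(), threads).max(1);
    let mut chunk_lens = Vec::new();
    let data_len = data.len();
    // Every thread spawned on `s` is joined before thread::scope returns,
    // so the threads may borrow `data` even though it is not 'static.
    thread::scope(|s| {
        for slice in data.chunks_mut(data_len / chunks) {
            let slice_len = slice.len();
            chunk_lens.push(slice_len);
            s.spawn(move || merge_sort(slice, 0, slice_len - 1));
        }
    });

    // Merge each sorted chunk into the sorted prefix data[0..=middle];
    // data[middle + 1..=end] is the chunk being merged in.
    let mut end: usize = chunk_lens[0] - 1;
    for &cl in &chunk_lens[1..] {
        let middle = end;
        end += cl;
        merge(data, 0, middle, end);
    }
}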
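Stripped of the sorting, the pattern that thread::scope enables looks like the sketch below. double_in_parallel is a hypothetical example, not part of the original code: it splits a slice with chunks_mut, mutates each chunk on its own scoped thread, and then uses the slice again after the scope ends.

```rust
use std::thread;

/// Hypothetical example: doubles every element of `data` in place using up to
/// `threads` scoped threads, each borrowing one chunk, then returns the new sum.
fn double_in_parallel(data: &mut [i64], threads: usize) -> i64 {
    if data.is_empty() {
        return 0;
    }
    let threads = threads.max(1);
    // Ceiling division, so there are at most `threads` chunks.
    let chunk_size = (data.len() + threads - 1) / threads;
    thread::scope(|s| {
        for chunk in data.chunks_mut(chunk_size) {
            // Each closure borrows a disjoint &mut chunk of the non-'static `data`.
            s.spawn(move || {
                for x in chunk.iter_mut() {
                    *x *= 2;
                }
            });
        }
    }); // every spawned thread has been joined at this point
    // The threads' borrows have ended, so `data` can be used directly again.
    data.iter().sum()
}
```

For example, calling it on the numbers 1 through 10 with threads = 3 splits them into chunks of 4, 4, and 2 elements, doubles each one, and returns 110.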

Result

You can see the full implementation including a way to run the code and tests in its own repo.

Sorting 10 million random integers with 4 threads on my Intel MacBook laptop, I get a roughly 2.5x speedup going from merge_sort to parallel_mergesort (from ~12.6 to ~5.1 seconds). That said, the multithreaded implementation is far from optimal. One of the biggest problems is that we sort the chunks in parallel but then merge them serially.

Regular mergesort has the recurrence MS(n) = 2MS(n/2) + O(n), which solves to O(n lg n). In our multithreaded version, the t chunks are sorted in parallel in MS(n/t) = O((n/t) lg(n/t)) time, but the serial merge phase afterwards takes at least O(n) time no matter how many threads we use. (Merging each chunk into the growing sorted prefix, as the loop above does, actually copies about (n/t)(2 + 3 + ... + t) elements, which grows like tn.) So the running time is roughly MS_t(n) = MS(n/t) + O(tn), and adding threads only shrinks the first term. CLRS section 27.3 makes the same point about a version that spawns the recursive calls at every level but keeps merge serial: its span satisfies S(n) = S(n/2) + Θ(n) = Θ(n), so its parallelism is only Θ(lg n).

In the next blog post, hopefully we can improve the parallelism of this implementation by parallelizing the merge operation at the end using binary search.