Cleaning our Rust with Monadic Functions


A couple of weeks ago, we explored how to add authentication to a Rocket Rust server. This involved writing a from_request function that was very messy. You can see the original version of that function in the appendix at the bottom. This week, we're going to improve that function! We'll explore functions like map and and_then in Rust. These let us write cleaner code using ideas similar to functors and monads in Haskell.

For more details on this code, take a look at our GitHub repo! For this article, you should look at rocket_auth_monads.rs. For a simpler introduction to Rust, take a look at our Rust Beginners Series!

Closures and Mapping

First, let's talk a bit about Rust's equivalent to fmap and functors. Suppose we have a value wrapped in an Option and a simple "doubling" function:

fn double(x: f64) -> f64 {
  2.0 * x
}

fn main() -> () {
  let x: Option<f64> = Some(5.0);
  ...
}

We'd like to pass our x value to the double function, but it's wrapped in the Option type. A logical thing to do would be to return None if the input is None, and otherwise apply the function and re-wrap in Some. In Haskell, we describe this behavior with the Functor class. Rust's approach has some similarities and some differences.
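To make that behavior concrete, here is a small sketch of mapping over an Option written out by hand with a match. The name map_option is our own, for illustration only; the standard library's Option::map behaves the same way.

```rust
// Hypothetical helper: apply `f` to the wrapped value, or pass None through.
fn map_option<A, B>(x: Option<A>, f: impl FnOnce(A) -> B) -> Option<B> {
    match x {
        Some(a) => Some(f(a)), // apply the function and re-wrap in Some
        None => None,          // nothing to transform
    }
}

fn double(x: f64) -> f64 {
    2.0 * x
}

fn main() {
    println!("{:?}", map_option(Some(5.0), double));
    println!("{:?}", map_option(None, double));
}
```

This is exactly the shape of Haskell's fmap for Maybe: the function never sees the wrapper, and the None case is handled once, in one place.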

Rust has no Functor trait. Instead, types such as Option, Result, and iterators each provide their own map method. As in Haskell, we pass map a function that transforms the underlying value. Here's how we can apply our simple example with an Option:

fn main() -> () {
  let x: Option<f64> = Some(5.0);
  let y: Option<f64> = x.map(double);
}

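One notable difference from Haskell is that map is a method defined separately on each type (Option::map, Result::map, Iterator::map) rather than one generic function. In Haskell, of course, there are no member functions, so fmap is a standalone function that works for any Functor instance.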
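Since each type supplies its own map, the same function can be mapped over several different wrappers. Here's a quick sketch with Option, Result, and an iterator; the input values are made up for illustration:

```rust
fn double(x: f64) -> f64 {
    2.0 * x
}

fn main() {
    // Option::map transforms the value inside Some and leaves None alone.
    let opt: Option<f64> = Some(5.0).map(double);

    // Result::map transforms the Ok value and leaves an Err untouched.
    let ok: Result<f64, String> = Ok(5.0).map(double);
    let err: Result<f64, String> = Err(String::from("no input")).map(double);

    // Iterator::map is lazy; collect() runs it and gathers the results.
    let many: Vec<f64> = vec![1.0, 2.0, 3.0].into_iter().map(double).collect();

    println!("{:?} {:?} {:?} {:?}", opt, ok, err, many);
}
```

The calls look alike, but each one resolves to a different method, which is the trade-off of not having a shared Functor abstraction.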

In Haskell, we can use lambda expressions as arguments to higher-order functions. Rust is the same, but it calls them closures. The syntax is rather different as well: we list the parameters between vertical bars, and then provide the body, which can be a brace-delimited code block. Here's a simple example:

fn main() -> () {
  let x: Option<f64> = Some(5.0);
  let y: Option<f64> = x.map(|x| {2.0 * x});
}

Type annotations are also possible (and sometimes necessary) on closures. Unlike in Haskell, we write them inline, right on the closure's parameters and return type. When we annotate the return type, the body must be a brace-delimited block:

fn main() -> () {
  let x: Option<f64> = Some(5.0);
  let y: Option<f64> = x.map(|x: f64| -> f64 {2.0 * x});
}
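One thing closures can do that a named fn like double can't is capture variables from their surrounding scope. Here's a small sketch; scale and factor are names we made up for illustration:

```rust
// Hypothetical helper: multiplies the wrapped value by `factor`.
fn scale(x: Option<f64>, factor: f64) -> Option<f64> {
    // The closure captures `factor` from the enclosing function.
    // A nested `fn` item would not be able to refer to it.
    x.map(|v: f64| -> f64 { factor * v })
}

fn main() {
    println!("{:?}", scale(Some(5.0), 3.0));
    println!("{:?}", scale(None, 3.0));
}
```

This capturing is what makes closures so handy with map and and_then: the transformation can depend on values computed earlier in the function.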

And Then…

Using map is all well and good, but our authentication example used the result of one effectful call as the input to the next. As most Haskellers can tell you, this is a job for monads and not merely functors. We can capture some of the same behavior with the and_then function in Rust, which works a lot like the bind operator (>>=) in Haskell. It also takes a function as input, but this function takes a plain (unwrapped) value and produces a wrapped, effectful output.
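To see the parallel with >>=, here is a sketch of and_then for Option written by hand. The names bind_option and half_if_even are hypothetical, for illustration only; the standard library's Option::and_then behaves this way.

```rust
// Hypothetical helper mirroring Haskell's (>>=) for Maybe.
// Unlike map, `f` already returns an Option, so we don't re-wrap its result.
fn bind_option<A, B>(x: Option<A>, f: impl FnOnce(A) -> Option<B>) -> Option<B> {
    match x {
        Some(a) => f(a), // f decides whether the next step succeeds
        None => None,    // an earlier failure short-circuits the rest
    }
}

// Example effectful step: only even numbers can be halved.
fn half_if_even(n: i32) -> Option<i32> {
    if n % 2 == 0 { Some(n / 2) } else { None }
}

fn main() {
    println!("{:?}", bind_option(Some(8), half_if_even));
    println!("{:?}", bind_option(Some(3), half_if_even));
}
```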

Here's how we apply it with Option. We start with a safe_square_root function that produces None when its input is negative. Then we can call and_then on our original Option, passing it the square root function.

fn safe_square_root(x: f64) -> Option<f64> {
  if x < 0.0 {
    None
  } else {
    Some(x.sqrt())
  }
}

fn main() -> () {
  let x: Option<f64> = Some(5.0);
  let y: Option<f64> = x.and_then(safe_square_root);
}
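Here's a quick sketch of why and_then, rather than map, is the right tool here. With map, the Option returned by safe_square_root ends up nested inside another Option. Chaining also shows the short-circuiting: once a step yields None, the remaining steps are skipped. The input values are our own.

```rust
fn safe_square_root(x: f64) -> Option<f64> {
    if x < 0.0 { None } else { Some(x.sqrt()) }
}

fn main() {
    let x: Option<f64> = Some(16.0);

    // map wraps the already-wrapped result again: Option<Option<f64>>.
    let nested: Option<Option<f64>> = x.map(safe_square_root);

    // and_then flattens it: Option<f64>.
    let flat: Option<f64> = x.and_then(safe_square_root);

    // Chained calls stop at the first None.
    let twice: Option<f64> = x.and_then(safe_square_root).and_then(safe_square_root);
    let negative: Option<f64> = Some(-4.0).and_then(safe_square_root).and_then(safe_square_root);

    println!("{:?} {:?} {:?} {:?}", nested, flat, twice, negative);
}
```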

Converting to Outcomes

Now let's switch gears to our authentication example. Our final result type wasn't Option: some intermediate results used Option, but in the end we needed a Rocket Outcome. So to help us on our way, let's write a simple function to convert our options into outcomes. We have to supply the extra information of what the failure result should be. This is the status_error parameter.

fn option_to_outcome<R>(
  result: Option<R>,
  status_error: (Status, LoginError))
  -> Outcome<R, LoginError> {
    match result {
        Some(r) => Outcome::Success(r),
        None => Outcome::Failure(status_error)
    }
}
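If you're working with the standard library's Result instead of Rocket's Outcome, you don't need a helper like this at all: Option::ok_or (and its lazily evaluated sibling ok_or_else) performs the same conversion. Here's a small sketch; LookupError and find_name are made-up names for illustration:

```rust
#[derive(Debug, Clone, PartialEq)]
enum LookupError {
    Missing, // hypothetical error variant
}

// Hypothetical stand-in for a lookup that may fail.
fn find_name(id: u32) -> Option<&'static str> {
    if id == 1 { Some("alice") } else { None }
}

fn main() {
    // ok_or turns Some(v) into Ok(v) and None into Err(the given error).
    let found: Result<&str, LookupError> = find_name(1).ok_or(LookupError::Missing);
    let missing: Result<&str, LookupError> = find_name(2).ok_or(LookupError::Missing);
    println!("{:?} {:?}", found, missing);
}
```

Like option_to_outcome, ok_or takes the failure value up front, which is why we pass the error explicitly rather than having it inferred.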

Now let's start our refactoring process. To begin, let's examine the retrieval of our username and password from the headers. We'll make a separate function for this. This should return an Outcome, where the success value is a tuple of two strings. We'll start by defining our failure outcome, a tuple of a status and our LoginError.

fn read_auth_from_headers(headers: &HeaderMap)
  -> Outcome<(String, String), LoginError> {
    let fail = (Status::BadRequest, LoginError::InvalidData);
    ...
}

We'll first retrieve the username from the headers. Recall that get_one returns an Option, so we can convert it to an Outcome using our function. We can then use and_then with a closure taking the unwrapped username.

fn read_auth_from_headers(headers: &HeaderMap)
  -> Outcome<(String, String), LoginError> {
    let fail = (Status::BadRequest, LoginError::InvalidData);
    option_to_outcome(headers.get_one("username"), fail.clone())
        .and_then(|u| -> Outcome<(String, String), LoginError> {
            ...
        })
}

We can then do the same thing with the password field. When we've successfully unwrapped both fields, we can return our final Success outcome.

fn read_auth_from_headers(headers: &HeaderMap)
  -> Outcome<(String, String), LoginError> {
    let fail = (Status::BadRequest, LoginError::InvalidData);
    option_to_outcome(headers.get_one("username"), fail.clone())
        .and_then(|u| {
            option_to_outcome(
              headers.get_one("password"), fail.clone())
                .and_then(|p| {
                    Outcome::Success(
                      (String::from(u), String::from(p)))
                })
        })
}
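Since the innermost closure only wraps its input in Success, that last step could also be written with map. Here's a sketch of the same shape using std types so it runs without Rocket: a HashMap stands in for Rocket's HeaderMap and Result stands in for Outcome. Both substitutions are our own assumptions, not the article's code.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum LoginError {
    InvalidData,
}

// HashMap stands in for HeaderMap; Result stands in for Outcome.
fn read_auth(headers: &HashMap<&str, &str>) -> Result<(String, String), LoginError> {
    let fail = LoginError::InvalidData;
    headers
        .get("username")
        .ok_or(fail.clone())
        .and_then(|u| {
            headers
                .get("password")
                .ok_or(fail)
                // The final step can't fail, so map is enough here.
                .map(|p| (u.to_string(), p.to_string()))
        })
}

fn main() {
    let mut headers = HashMap::new();
    headers.insert("username", "alice");
    headers.insert("password", "hunter2");
    println!("{:?}", read_auth(&headers));

    headers.remove("password");
    println!("{:?}", read_auth(&headers));
}
```

The general rule: use and_then when the next step can fail, and map when it can't.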

Re-Organizing

Armed with this function, we can start re-tooling our from_request function. We'll start by reading the headers and invoking and_then on the result. This unwraps the username and password:

impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
    type Error = LoginError;
    fn from_request(request: &'a Request<'r>)
      -> Outcome<AuthenticatedUser, LoginError> {
        let headers_result = 
              read_auth_from_headers(&request.headers());
        headers_result.and_then(|(u, p)| {
            ...
        })
    }
}

For the next step, we'll make a couple of database calls. Both of our lookup functions return Option values. So for each, we'll create a failure tuple and invoke option_to_outcome, following it up with a call to and_then. First we get the user based on the username. Then we find their AuthInfo using the user's ID.

impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
    type Error = LoginError;
    fn from_request(request: &'a Request<'r>)
      -> Outcome<AuthenticatedUser, LoginError> {
        let headers_result = 
              read_auth_from_headers(&request.headers());
        headers_result.and_then(|(u, p)| {
            let conn_str = local_conn_string();
            let maybe_user =
                  fetch_user_by_email(&conn_str, &String::from(u));
            let fail1 =
                 (Status::NotFound, LoginError::UsernameDoesNotExist);
            option_to_outcome(maybe_user, fail1)
                .and_then(|user: UserEntity| {
                    let fail2 = (Status::MovedPermanently, 
                                 LoginError::WrongPassword);
                    option_to_outcome(
                      fetch_auth_info_by_user_id(
                        &conn_str, user.id), fail2)
                })
                .and_then(|auth_info: AuthInfoEntity| {
                   ...
                })
        })
    }
}

This gives us the unwrapped authentication info. We can compare its stored hash against the hash of the submitted password and return our final Outcome!

impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
    type Error = LoginError;
    fn from_request(request: &'a Request<'r>)
      -> Outcome<AuthenticatedUser, LoginError> {
        let headers_result = 
              read_auth_from_headers(&request.headers());
        headers_result.and_then(|(u, p)| {
            let conn_str = local_conn_string();
            let maybe_user =
                  fetch_user_by_email(&conn_str, &String::from(u));
            let fail1 =
                 (Status::NotFound, LoginError::UsernameDoesNotExist);
            option_to_outcome(maybe_user, fail1)
                .and_then(|user: UserEntity| {
                    let fail2 = (Status::MovedPermanently, 
                                 LoginError::WrongPassword);
                    option_to_outcome(
                      fetch_auth_info_by_user_id(
                        &conn_str, user.id), fail2)
                })
                .and_then(|auth_info: AuthInfoEntity| {
                    let hash = hash_password(&String::from(p));
                    if hash == auth_info.password_hash {
                        Outcome::Success(AuthenticatedUser {
                            user_id: auth_info.user_id})
                    } else {
                        Outcome::Failure(
                          (Status::Forbidden, 
                           LoginError::WrongPassword))
                    }
                })
        })
    }
}
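To experiment with this chain outside Rocket, here's a self-contained sketch of the same pipeline using Result. The in-memory lookups, the entity structs, and the toy hash_password are hypothetical stand-ins for the article's database and hashing code, and plain u16 status codes stand in for Rocket's Status.

```rust
#[derive(Debug, Clone, PartialEq)]
enum LoginError {
    UsernameDoesNotExist,
    WrongPassword,
}

struct UserEntity { id: i32 }
struct AuthInfoEntity { user_id: i32, password_hash: String }

#[derive(Debug, PartialEq)]
struct AuthenticatedUser { user_id: i32 }

// Hypothetical stand-in for the user lookup.
fn fetch_user_by_email(email: &str) -> Option<UserEntity> {
    if email == "alice@example.com" { Some(UserEntity { id: 7 }) } else { None }
}

// Hypothetical stand-in for the auth info lookup.
fn fetch_auth_info_by_user_id(id: i32) -> Option<AuthInfoEntity> {
    if id == 7 {
        Some(AuthInfoEntity { user_id: 7, password_hash: hash_password("secret") })
    } else {
        None
    }
}

// Toy "hash" (string reversal) for illustration only; NOT secure.
fn hash_password(p: &str) -> String {
    p.chars().rev().collect()
}

// Each and_then runs only if the previous step succeeded.
fn authenticate(u: &str, p: &str) -> Result<AuthenticatedUser, (u16, LoginError)> {
    fetch_user_by_email(u)
        .ok_or((404, LoginError::UsernameDoesNotExist))
        .and_then(|user| {
            fetch_auth_info_by_user_id(user.id).ok_or((301, LoginError::WrongPassword))
        })
        .and_then(|auth_info| {
            if hash_password(p) == auth_info.password_hash {
                Ok(AuthenticatedUser { user_id: auth_info.user_id })
            } else {
                Err((403, LoginError::WrongPassword))
            }
        })
}

fn main() {
    println!("{:?}", authenticate("alice@example.com", "secret"));
    println!("{:?}", authenticate("alice@example.com", "oops"));
    println!("{:?}", authenticate("bob@example.com", "secret"));
}
```

Note that the chain stays flat: the auth info step only needs the user's ID, and the password check only needs the auth info (plus p, which the closure captures), so no step has to be nested inside the previous one.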

Conclusion

Is this new solution much better than our original? It avoids the "triangle of death" pattern of nested match expressions. But it's not necessarily much shorter. Perhaps it's a little cleaner on the whole, though. Ultimately these code choices are up to you! Next time, we'll wrap up our current exploration of Rust by seeing how to profile our code.

This series has covered some more advanced topics in Rust. For a more in-depth introduction, check out our Rust Video Tutorial!

Appendix: Original Function

impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
    type Error = LoginError;
    fn from_request(request: &'a Request<'r>) -> Outcome<AuthenticatedUser, LoginError> {
        let username = request.headers().get_one("username");
        let password = request.headers().get_one("password");
        match (username, password) {
            (Some(u), Some(p)) => {
                let conn_str = local_conn_string();
                let maybe_user = fetch_user_by_email(&conn_str, &String::from(u));
                match maybe_user {
                    Some(user) => {
                        let maybe_auth_info = fetch_auth_info_by_user_id(&conn_str, user.id);
                        match maybe_auth_info {
                            Some(auth_info) => {
                                let hash = hash_password(&String::from(p));
                                if hash == auth_info.password_hash {
                                    Outcome::Success(AuthenticatedUser{user_id: auth_info.user_id})
                                } else {
                                    Outcome::Failure((Status::Forbidden, LoginError::WrongPassword))
                                }
                            }
                            None => {
                                Outcome::Failure((Status::MovedPermanently, LoginError::WrongPassword))
                            }
                        }
                    }
                    None => Outcome::Failure((Status::NotFound, LoginError::UsernameDoesNotExist))
                }
            },
            _ => Outcome::Failure((Status::BadRequest, LoginError::InvalidData))
        }
    }
}