Haskellings 2: Better Configuration
This week we'll continue working on our nascent Haskellings project. There are a few interesting things we'll learn here around using the directory package. We'll also explore how to use the Seq data structure to implement a quick Breadth-First Search algorithm!
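The Haskellings post uses Haskell's Seq as the queue for a breadth-first search. To stay in the Rust used elsewhere on this page, here is a rough analog in which std's VecDeque plays the role of Seq. This is not the Haskellings code. The adjacency-list graph and the `bfs_order` function are hypothetical illustrations of the queue discipline only.

```rust
use std::collections::{HashSet, VecDeque};

/// Breadth-first search over an adjacency list, returning nodes in visit order.
/// `graph[i]` lists the neighbours of node `i` (a hypothetical representation).
pub fn bfs_order(graph: &[Vec<usize>], start: usize) -> Vec<usize> {
    let mut visited = HashSet::new();
    let mut queue = VecDeque::new();
    let mut order = Vec::new();

    visited.insert(start);
    queue.push_back(start);

    // Pop from the front and push to the back: first in, first out.
    while let Some(node) = queue.pop_front() {
        order.push(node);
        for &next in &graph[node] {
            // `insert` returns false if the node was already seen.
            if visited.insert(next) {
                queue.push_back(next);
            }
        }
    }
    order
}
```

Marking nodes as visited when they are enqueued, rather than when they are dequeued, ensures that each node enters the queue at most once.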
Starting Haskellings!
After learning about the Rustlings program in the last few weeks, we're now going to try to start replicating it in Haskell! In this week's video blog, we'll learn a little bit about using the ghc command on its own outside of Stack/Cabal, and then how to run it from within our program using the System.Process library.
Rustlings Part 2
This week we continue with another Rustlings video tutorial! We'll tackle some more advanced concepts like move semantics, traits, and generics! Next week, we'll start considering how we might build a similar program to teach beginners about Haskell!
Rustlings Video Blog!
We're doing something very new this week. Instead of doing a code writeup, I've actually made a video. In keeping with the last couple months of content, this first one is still Rust related. We'll walk through the Rustlings tool, which is an interactive program that teaches you the basics of the Rust language! Soon, we'll start exploring how we might do this in Haskell!
You can also watch this video on our YouTube Channel! Subscribe there or sign up for our mailing list!
Rust Web Series Complete!
We're taking a quick breather this week from new content for an announcement. Our recently concluded Rust Web series now has a permanent spot on the advanced page of our website. You can take a look at the series page here! Here's a quick summary of the series:
- Part 1: Postgres - In the first part, we learn about a basic library that enables integration with a PostgreSQL database.
- Part 2: Diesel - Next up, we get a little more formal with our database mechanics. We use the Diesel library to provide a schema for our database application.
- Part 3: Rocket - In part 3, we take the next step and start making a web server! We'll learn the basics of the Rocket server library!
- Part 4: CRUD Server - What do we do once we have a database and server library? Combine them of course! In this part, we'll make a CRUD server that can access our database elements using Diesel and Rocket.
- Part 5: Authentication - If your server will actually serve real users, you'll need authentication at some point. We'll see the different mechanisms we can use with Rocket for securing our endpoints.
- Part 6: Front-end Templating - If you're serving a full front-end web app, you'll need some way to customize the HTML. In the last part of the series, we'll see how Rocket makes this easy!
The best part is that you can find all the code for the series on our Github Repo! So be sure to take a look there. And if you're still new to Rust, you can also get your feet wet first with our Beginners Series.
In other exciting news, we'll be trying a completely new kind of content in the next couple weeks. I've written a bit in the past about using different IDEs like Atom and IntelliJ to write Haskell. I'd like to revisit these ideas to give a clearer idea of how to make our lives easier when writing code. But instead of writing articles, I'll be making a few videos to showcase how these work! I hope that a visual display of the IDEs will make the content clearer.
Unit Tests and Benchmarks in Rust
For a couple months now, we've focused on some specific libraries you can use in Rust for web development. But we shouldn't lose sight of some other core language skills and mechanics. Whenever you write code, you should be able to show first that it works, and second that it works efficiently. If you're going to build a larger Rust app, you should also know a bit about unit testing and benchmarking. This week, we'll take a couple simple sorting algorithms as our examples to learn these skills.
As always, you can take a look at the code for this article on our Github Repo for the series. You can find this week's code specifically in sorters.rs! For a more basic introduction to Rust, be sure to check out our Rust Beginners Series!
Insertion Sort
We'll start out this article by implementing insertion sort. This is one of the simpler sorting algorithms, though it is rather inefficient. We'll perform this sort "in place". This means our function won't return a value. Rather, we'll pass a mutable reference to our vector so we can manipulate its items. To help out, we'll also define a swap function to exchange two elements through that same reference:
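// Exchange the elements at indices i and j.
// (The built-in numbers.swap(i, j) does the same thing.)
pub fn swap(numbers: &mut Vec<i32>, i: usize, j: usize) {
    let temp = numbers[i];
    numbers[i] = numbers[j];
    numbers[j] = temp;
}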
pub fn insertion_sorter(numbers: &mut Vec<i32>) {
...
}
At its core, insertion sort is a pretty simple algorithm. We maintain the invariant that the "left" part of the array is always sorted. (At the start, with only 1 element, this is clearly true.) Then we loop through the array and "absorb" the next element into our sorted part. To absorb the element, we loop backwards through our sorted portion. Each time the element to its left is larger, we swap the two. Once we reach an element that is not larger (or the front of the array), we know the left side is once again sorted.
pub fn insertion_sorter(numbers: &mut Vec<i32>) {
    for i in 1..numbers.len() {
        let mut j = i;
        while j > 0 && numbers[j-1] > numbers[j] {
            swap(numbers, j, j - 1);
            j = j - 1;
        }
    }
}
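The version above is tied to `&mut Vec<i32>` and a custom swap helper. Here is a minimal sketch of a more general variant that works on any mutable slice of `Ord` values and uses the standard library's `slice::swap`. The name `insertion_sort` is hypothetical, and the algorithm is the same one described above.

```rust
/// Insertion sort over any mutable slice of ordered values.
/// Taking `&mut [T]` instead of `&mut Vec<i32>` lets callers pass a whole
/// Vec (via deref coercion) or only part of one.
pub fn insertion_sort<T: Ord>(items: &mut [T]) {
    for i in 1..items.len() {
        let mut j = i;
        // Move the new element left while its left neighbour is larger.
        while j > 0 && items[j - 1] > items[j] {
            items.swap(j - 1, j); // built-in slice method, no helper needed
            j -= 1;
        }
    }
}
```

Because it takes a slice, `insertion_sort(&mut v[..3])` sorts only the first three elements and leaves the rest untouched.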
Testing
Our algorithm is simple enough. But how do we know it works? The obvious answer is to write some unit tests for it. Rust's canonical approach to unit tests differs a bit from Haskell and most other languages. In those languages, you'll usually make a separate test directory. But Rust encourages you to write unit tests in the same file as the function definition. We do this by having a section at the bottom of our file specifically for tests. We mark each test function with the test attribute:
#[test]
fn test_insertion_sort() {
    ...
}
To keep things simple, we'll define a random vector of 100 integers and pass it to our function. We'll use assert! to verify that each number is no larger than the one after it.
#[test]
fn test_insertion_sort() {
    let mut numbers: Vec<i32> = random_vector(100);
    insertion_sorter(&mut numbers);
    for i in 0..(numbers.len() - 1) {
        assert!(numbers[i] <= numbers[i + 1]);
    }
}
When we run the cargo test command, Cargo will automatically detect that we have a test suite in this file and run it.
running 1 test...
test sorter::test_insertion_sort ... ok
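The tests call `random_vector`, which isn't shown in this article. Here is a hypothetical std-only stand-in based on a tiny xorshift generator, so no external crate is needed. It also includes a hypothetical checker with a stronger condition than the adjacent-pair test. A "sorter" that overwrote every element with 0 would pass the adjacency check. Comparing against the standard library's sort also catches lost or altered elements.

```rust
use std::time::{SystemTime, UNIX_EPOCH};

/// Hypothetical stand-in for the article's `random_vector` helper.
/// Uses a small xorshift generator; NOT suitable for anything security-related.
pub fn random_vector(len: usize) -> Vec<i32> {
    // Seed from the clock; `| 1` keeps the state non-zero, which xorshift needs.
    let mut state = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos() as u64)
        .unwrap_or(0x9E37_79B9_7F4A_7C15)
        | 1;
    (0..len)
        .map(|_| {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            (state % 1000) as i32 // values in 0..1000
        })
        .collect()
}

/// True only if `sorted` is ordered AND holds exactly the elements of
/// `original`, checked by comparing with the standard library's sort.
pub fn is_sorted_permutation(original: &[i32], sorted: &[i32]) -> bool {
    let mut expected = original.to_vec();
    expected.sort();
    expected == sorted
}
```

In a test, you would keep a copy of the input before sorting and then assert `is_sorted_permutation(&copy, &numbers)`.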
Benchmarking
So we know our code works, but how quickly does it work? When you want to check the performance of your code, you need to establish benchmarks. These are like test suites except that they're meant to report the average time it takes to perform a task.
Just as we had a test attribute for making test suites, we can use the bench attribute for benchmarks. Each benchmark function takes a mutable reference to a Bencher object as an argument. To record some code, we'll call iter on that object and pass a closure that will run our function.
#[bench]
fn bench_insertion_sort_100_ints(b: &mut Bencher) {
    b.iter(|| {
        let mut numbers: Vec<i32> = random_vector(100);
        insertion_sorter(&mut numbers)
    });
}
We can then run the benchmark with cargo bench.
running 2 tests
test sorter::test_insertion_sort ... ignored
test sorter::bench_insertion_sort_100_ints ... bench: 6,537 ns/iter (+/- 1,541)
So on average, it took about 6.5 microseconds (6,537 ns) to sort 100 numbers. On its own, this number doesn't tell us much. But we can get a clearer idea of our algorithm's runtime by looking at benchmarks of different sizes. Suppose we make lists of 1000 and 10000:
#[bench]
fn bench_insertion_sort_1000_ints(b: &mut Bencher) {
    b.iter(|| {
        let mut numbers: Vec<i32> = random_vector(1000);
        insertion_sorter(&mut numbers)
    });
}
#[bench]
fn bench_insertion_sort_10000_ints(b: &mut Bencher) {
    b.iter(|| {
        let mut numbers: Vec<i32> = random_vector(10000);
        insertion_sorter(&mut numbers)
    });
}
Now when we run the benchmark, we can compare the results of these different runs:
running 4 tests
test sorter::test_insertion_sort ... ignored
test sorter::bench_insertion_sort_10000_ints ... bench: 65,716,130 ns/iter (+/- 11,193,188)
test sorter::bench_insertion_sort_1000_ints ... bench: 612,373 ns/iter (+/- 124,732)
test sorter::bench_insertion_sort_100_ints ... bench: 12,032 ns/iter (+/- 904)
We see that when we increase the problem size by a factor of 10, the runtime grows by a factor of roughly 50 to 100! This is consistent with the O(n^2) asymptotic runtime we expect from insertion sort, which is not very good.
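The bench attribute used above is unstable. It needs a nightly compiler, `#![feature(test)]` at the crate root, and `extern crate test;` with `use test::Bencher;`. On stable Rust, one rough alternative is to time calls with `std::time::Instant`. Below is a minimal sketch using a hypothetical helper, `average_sort_time`. Unlike the benchmarks above, it builds each input before starting the clock, so random vector generation isn't counted. It does no warm-up or outlier filtering, so treat its numbers as rough. Dedicated crates such as criterion handle those details.

```rust
use std::time::{Duration, Instant};

/// Hypothetical helper: average wall-clock time of `sorter` over `iterations`
/// fresh inputs. Each input is built before the clock starts, so only the
/// sort itself is measured. No warm-up or outlier filtering is done.
pub fn average_sort_time<F>(sorter: fn(&mut Vec<i32>), make_input: F, iterations: u32) -> Duration
where
    F: Fn() -> Vec<i32>,
{
    let mut total = Duration::ZERO;
    for _ in 0..iterations {
        let mut numbers = make_input();
        let start = Instant::now();
        sorter(&mut numbers);
        total += start.elapsed();
    }
    // `max(1)` avoids dividing by zero when `iterations` is 0.
    total / iterations.max(1)
}
```

A call such as `average_sort_time(insertion_sorter, || random_vector(1000), 50)` would then report an approximate per-sort time.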
Quick Sort
There are many ways to sort more efficiently! Let's try our hand at quicksort. For this algorithm, we first "partition" our array. We'll choose a pivot value, and then move all the numbers smaller than the pivot to the left of the array, and all the greater numbers to the right. The upshot is that we know our pivot element is now in the correct final spot!
Here's what the partition algorithm looks like. It works on a specific sub-segment of our vector, indicated by start and end. We initially move the pivot element to the back, and then loop through the other elements of the array. The i index tracks where our pivot will end up. Each time we encounter a smaller number, we swap it into position i and increment i. At the very end we swap our pivot element back into its place, and return its final index.
pub fn partition(
    numbers: &mut Vec<i32>,
    start: usize,
    end: usize,
    partition: usize)
    -> usize {
    let pivot_element = numbers[partition];
    swap(numbers, partition, end - 1);
    let mut i = start;
    for j in start..(end - 1) {
        if numbers[j] < pivot_element {
            swap(numbers, i, j);
            i = i + 1;
        }
    }
    swap(numbers, i, end - 1);
    i
}
So to finish sorting, we'll set up a recursive helper that, again, operates on a sub-segment of the array. We'll choose a random element and partition around it:
pub fn quick_sorter_helper(
    numbers: &mut Vec<i32>, start: usize, end: usize) {
    if start >= end {
        return;
    }
    let mut rng = thread_rng();
    let initial_partition = rng.gen_range(start, end);
    let partition_index =
        partition(numbers, start, end, initial_partition);
    ...
}
Now that we've partitioned, all that's left to do is recursively sort each side of the partition! Our main API function will call this helper with the full size of the array.
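// Assumes the rand crate: use rand::{thread_rng, Rng};
// gen_range(low, high) is the rand 0.7 signature; rand 0.8+ takes a range: gen_range(low..high).
pub fn quick_sorter_helper(
    numbers: &mut Vec<i32>, start: usize, end: usize) {
    // An empty segment is already sorted.
    if start >= end {
        return;
    }
    let mut rng = thread_rng();
    let initial_partition = rng.gen_range(start, end);
    let partition_index =
        partition(numbers, start, end, initial_partition);
    // The pivot is in its final place; sort each side of it.
    quick_sorter_helper(numbers, start, partition_index);
    quick_sorter_helper(numbers, partition_index + 1, end);
}
pub fn quick_sorter(numbers: &mut Vec<i32>) {
    quick_sorter_helper(numbers, 0, numbers.len());
}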
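The helper above passes start and end indices through every call. Below is a sketch of an alternative design that works on slices instead. It uses `split_at_mut` to give each recursive call its own sub-slice, so the borrow checker guarantees the two halves don't overlap. To avoid depending on the rand crate, this sketch picks the middle element as the pivot rather than a random one. That choice is a deliberate deviation from the article's version. The names `quick_sort` and `partition` here are illustrative.

```rust
/// Quicksort over a mutable slice, recursing on sub-slices instead of
/// passing start/end indices. Pivot: the middle element (not random).
pub fn quick_sort<T: Ord>(items: &mut [T]) {
    if items.len() <= 1 {
        return;
    }
    let mid = items.len() / 2;
    let pivot_index = partition(items, mid);
    // `left` is everything before the pivot; `right[0]` is the pivot itself,
    // already in its final position, so we skip it.
    let (left, right) = items.split_at_mut(pivot_index);
    quick_sort(left);
    quick_sort(&mut right[1..]);
}

/// Lomuto-style partition, like the article's: move the pivot to the end,
/// sweep smaller elements to the front, then swap the pivot into its final
/// slot and return that index.
fn partition<T: Ord>(items: &mut [T], pivot: usize) -> usize {
    let last = items.len() - 1;
    items.swap(pivot, last);
    let mut store = 0;
    for j in 0..last {
        if items[j] < items[last] {
            items.swap(store, j);
            store += 1;
        }
    }
    items.swap(store, last);
    store
}
```

A fixed middle pivot avoids the classic worst case on already sorted input. This partition scheme still degrades to O(n^2) on inputs with many equal elements, however, whichever pivot is chosen.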
Now that we've got this function, let's add tests and benchmarks for it:
#[test]
fn test_quick_sort() {
    let mut numbers: Vec<i32> = random_vector(100);
    quick_sorter(&mut numbers);
    for i in 0..(numbers.len() - 1) {
        assert!(numbers[i] <= numbers[i + 1]);
    }
}
#[bench]
fn bench_quick_sort_100_ints(b: &mut Bencher) {
    b.iter(|| {
        let mut numbers: Vec<i32> = random_vector(100);
        quick_sorter(&mut numbers)
    });
}
// Same kind of benchmarks for 1000, 10000, 100000
Then we can run our benchmarks and see our results:
running 9 tests
test sorter::test_insertion_sort ... ignored
test sorter::test_quick_sort ... ignored
test sorter::bench_insertion_sort_10000_ints ... bench: 65,130,880 ns/iter (+/- 49,548,187)
test sorter::bench_insertion_sort_1000_ints ... bench: 312,300 ns/iter (+/- 243,337)
test sorter::bench_insertion_sort_100_ints ... bench: 6,159 ns/iter (+/- 4,139)
test sorter::bench_quick_sort_100000_ints ... bench: 14,292,660 ns/iter (+/- 5,815,870)
test sorter::bench_quick_sort_10000_ints ... bench: 1,263,985 ns/iter (+/- 622,788)
test sorter::bench_quick_sort_1000_ints ... bench: 105,443 ns/iter (+/- 65,812)
test sorter::bench_quick_sort_100_ints ... bench: 9,259 ns/iter (+/- 3,882)
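Quicksort does much better on the larger inputs, as expected! (At 100 elements insertion sort is actually a bit faster, likely because quicksort's extra overhead, such as random pivot selection and recursion, dominates for small inputs.) Each tenfold increase in input size now raises the time by a factor of only about 11 to 12. Benchmarks alone can't prove that the true runtime is O(n log n), but we can clearly see that the growth is much closer to linear than to quadratic!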
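As a back-of-the-envelope check, here is what each cost model predicts for a tenfold increase in input size. This assumes T(n) ≈ c·n² for insertion sort and T(n) ≈ c·n log n for quicksort, with constants and lower-order terms ignored. The ratio of logarithms doesn't depend on the log base. The observed quicksort ratios from the run above are 1,263,985 / 105,443 ≈ 12.0 (1000 to 10000) and 14,292,660 / 1,263,985 ≈ 11.3 (10000 to 100000).

```latex
\begin{aligned}
\text{quadratic:}\quad \frac{T(10n)}{T(n)} &= \frac{c\,(10n)^2}{c\,n^2} = 100 \\[4pt]
n\log n:\quad \frac{T(10n)}{T(n)} &= \frac{c\,(10n)\log(10n)}{c\,n\log n} = 10\cdot\frac{\log(10n)}{\log n} \\[4pt]
n = 10^3:\quad 10\cdot\frac{\log 10^4}{\log 10^3} &= 10\cdot\frac{4}{3} \approx 13.3 \\[4pt]
n = 10^4:\quad 10\cdot\frac{\log 10^5}{\log 10^4} &= 10\cdot\frac{5}{4} = 12.5
\end{aligned}
```

The observed ratios sit near the n log n predictions and far below 100, which fits the conclusion above without proving it.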
Conclusion
That's all for this intermediate series on Rust! Next week, we'll summarize the skills we learned over the course of these couple months in Rust. Then we'll look ahead to our next series of topics, including some totally new kinds of content!
Don't forget! If you've never programmed in Rust before, our Rust Video Tutorial provides an in-depth introduction to the basics!
Cleaning our Rust with Monadic Functions
A couple weeks ago we explored how to add authentication to a Rocket Rust server. This involved writing a from_request function that was very messy. You can see the original version of that function as an appendix at the bottom. But this week, we're going to try to improve that function! We'll explore functions like map and and_then in Rust. These can help us write cleaner code using similar ideas to functors and monads in Haskell.
For more details on this code, take a look at our Github Repo! For this article, you should look at rocket_auth_monads.rs. For a simpler introduction to Rust, take a look at our Rust Beginners Series!
Closures and Mapping
First, let's talk a bit about Rust's equivalent to fmap and functors. Suppose we have a simple Option value and a "doubling" function:
fn double(x: f64) -> f64 {
    2.0 * x
}
fn main() -> () {
    let x: Option<f64> = Some(5.0);
    ...
}
We'd like to pass our x value to the double function, but it's wrapped in the Option type. A logical thing to do would be to return None if the input is None, and otherwise apply the function and re-wrap the result in Some. In Haskell, we describe this behavior with the Functor class. Rust's approach has some similarities and some differences.
Rust doesn't have a Functor trait. Instead, types like Option and Result each provide their own map method, and the Iterator trait provides a map for iterators. As in Haskell, we provide a function that transforms the underlying items. Here's how we can apply our simple example with an Option:
fn main() -> () {
    let x: Option<f64> = Some(5.0);
    let y: Option<f64> = x.map(double);
}
One notable difference from Haskell is that map is a method defined on the type itself (here, Option) rather than a standalone function. In Haskell, of course, there's no such thing as member functions, so fmap exists on its own as a method of the Functor class.
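To make the point about map concrete, here is a small sketch showing three separate map methods: Option::map, Result::map, and Iterator::map. They behave like fmap for their respective types, but no common trait ties them together. The `map_examples` wrapper is hypothetical and exists only for illustration.

```rust
fn double(x: f64) -> f64 {
    2.0 * x
}

/// Each `map` below is a different method on a different type.
pub fn map_examples() -> (Option<f64>, Option<f64>, Result<f64, String>, Vec<f64>) {
    let some = Some(5.0).map(double); // Option::map: Some(5.0) -> Some(10.0)
    let none: Option<f64> = None.map(double); // None stays None
    let ok: Result<f64, String> = Ok(1.5).map(double); // Result::map leaves Err untouched
    let doubled: Vec<f64> = vec![1.0, 2.0].into_iter().map(double).collect(); // Iterator::map is lazy until collect
    (some, none, ok, doubled)
}
```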
In Haskell, we can use lambda expressions as arguments to higher-order functions. In Rust, it's the same, but they're referred to as closures instead. The syntax is rather different as well. We list the parameters between vertical bars, followed by the body, which can be a brace-delimited block. Here's a simple example:
fn main() -> () {
let x: Option<f64> = Some(5.0);
let y: Option<f64> = x.map(|x| {2.0 * x});
}
Type annotations are also possible (and sometimes necessary) when specifying the closure. Unlike Haskell, we provide these inline, in the closure definition itself. (When we annotate the return type, the body must be a brace-delimited block.)
fn main() -> () {
let x: Option<f64> = Some(5.0);
let y: Option<f64> = x.map(|x: f64| -> f64 {2.0 * x});
}
And Then…
Now using map is all well and good, but our authentication example involved using the result of one effectful call in the next effect. As most Haskellers can tell you, this is a job for monads and not merely functors. We can capture some of the same effects of monads with the and_then function in Rust. This works a lot like the bind operator (>>=) in Haskell. It also takes an input function, and this function takes a pure input but produces an effectful output.
Here's how we apply it with Option. We start with a safe_square_root function that produces None when its input is negative. Then we can take our original Option and use and_then to apply the square root function.
fn safe_square_root(x: f64) -> Option<f64> {
    if x < 0.0 {
        None
    } else {
        Some(x.sqrt())
    }
}
fn main() -> () {
    let x: Option<f64> = Some(5.0);
    let y: Option<f64> = x.and_then(safe_square_root);
}
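The benefit of and_then shows up when several fallible steps are chained. Here is a minimal sketch with a hypothetical second step, `safe_reciprocal`, that fails on zero. The sketch also shows the same pipeline written with the `?` operator. In a function that returns Option, `?` returns None early, which reads somewhat like Haskell's do-notation for Maybe.

```rust
fn safe_square_root(x: f64) -> Option<f64> {
    if x < 0.0 { None } else { Some(x.sqrt()) }
}

/// Hypothetical second fallible step, so there are two effects to chain.
fn safe_reciprocal(x: f64) -> Option<f64> {
    if x == 0.0 { None } else { Some(1.0 / x) }
}

/// Chained with `and_then`, roughly `safeSquareRoot x >>= safeReciprocal` in Haskell.
pub fn root_then_reciprocal(x: f64) -> Option<f64> {
    Some(x).and_then(safe_square_root).and_then(safe_reciprocal)
}

/// The same pipeline using `?`, which returns `None` from this function
/// as soon as a step fails.
pub fn root_then_reciprocal_q(x: f64) -> Option<f64> {
    let root = safe_square_root(x)?;
    safe_reciprocal(root)
}
```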
Converting to Outcomes
Now let's switch gears to our authentication example. Our final result type wasn't Option. Some intermediate results used this. But in the end, we wanted an Outcome. So to help us on our way, let's write a simple function to convert our options into outcomes. We'll have to provide the extra information of what the failure result should be. This is the status_error parameter.
fn option_to_outcome<R>(
    result: Option<R>,
    status_error: (Status, LoginError))
    -> Outcome<R, LoginError> {
    match result {
        Some(r) => Outcome::Success(r),
        None => Outcome::Failure(status_error)
    }
}
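The standard library already has this pattern for Result: `Option::ok_or` turns Some into Ok and None into a given Err. Below is a sketch of the same conversion against `std::result::Result`. The `Status` and `LoginError` enums are simplified, hypothetical stand-ins for Rocket's type and the article's type. Note that `ok_or` evaluates its argument eagerly. When building the error is expensive, `ok_or_else` takes a closure instead.

```rust
/// Hypothetical stand-ins for Rocket's `Status` and the article's `LoginError`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Status { BadRequest, NotFound }

#[derive(Debug, Clone, PartialEq)]
pub enum LoginError { InvalidData, UsernameDoesNotExist }

/// Same "Some => success, None => given failure" conversion as
/// `option_to_outcome`, but producing a `Result`.
pub fn option_to_result<R>(
    value: Option<R>,
    status_error: (Status, LoginError),
) -> Result<R, (Status, LoginError)> {
    value.ok_or(status_error)
}
```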
Now let's start our refactoring process. To begin, let's examine the retrieval of our username and password from the headers. We'll make a separate function for this. This should return an Outcome, where the success value is a tuple of two strings. We'll start by defining our failure outcome, a tuple of a status and our LoginError.
fn read_auth_from_headers(headers: &HeaderMap)
-> Outcome<(String, String), LoginError> {
let fail = (Status::BadRequest, LoginError::InvalidData);
...
}
We'll first retrieve the username out of the headers. Recall that this operation returns an Option. So we can convert it to an Outcome using our function. We can then use and_then with a closure taking the unwrapped username.
fn read_auth_from_headers(headers: &HeaderMap)
-> Outcome<(String, String), LoginError> {
let fail = (Status::BadRequest, LoginError::InvalidData);
option_to_outcome(headers.get_one("username"), fail.clone())
.and_then(|u| -> Outcome<(String, String), LoginError> {
...
})
}
We can then do the same thing with the password field. When we've successfully unwrapped both fields, we can return our final Success outcome.
fn read_auth_from_headers(headers: &HeaderMap)
    -> Outcome<(String, String), LoginError> {
    let fail = (Status::BadRequest, LoginError::InvalidData);
    option_to_outcome(headers.get_one("username"), fail.clone())
        .and_then(|u| {
            option_to_outcome(
                headers.get_one("password"), fail.clone())
                .and_then(|p| {
                    Outcome::Success(
                        (String::from(u), String::from(p)))
                })
        })
}
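The two header lookups are independent, so they don't strictly need to be nested. Below is a sketch of a flatter variant using `Option::zip` (stable since Rust 1.46). It pairs two options into one and yields None if either is missing. A plain `HashMap` stands in for Rocket's `HeaderMap`, and the function returns a std `Result`. Both substitutions are assumptions made for illustration, and `read_auth` is a hypothetical name.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
pub enum LoginError { InvalidData }

/// Header-reading step with a `HashMap` in place of `HeaderMap`.
/// `zip` combines both lookups, so one failure value covers either header
/// being absent, with no nested closures.
pub fn read_auth(headers: &HashMap<&str, &str>) -> Result<(String, String), LoginError> {
    headers
        .get("username")
        .zip(headers.get("password"))
        .map(|(u, p)| (u.to_string(), p.to_string()))
        .ok_or(LoginError::InvalidData)
}
```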
Re-Organizing
Armed with this function, we can start re-tooling our from_request function. We'll start by gathering the header results and invoking and_then. This unwraps the username and password:
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
    type Error = LoginError;
    fn from_request(request: &'a Request<'r>)
        -> Outcome<AuthenticatedUser, LoginError> {
        let headers_result =
            read_auth_from_headers(&request.headers());
        headers_result.and_then(|(u, p)| {
            ...
        })
    }
}
Now for the next step, we'll make a couple of database calls. Both of our database lookup functions return Option values. So for each, we'll create a failure Outcome and invoke option_to_outcome. We'll follow this up with a call to and_then. First we get the user based on the username. Then we find their AuthInfo using the ID.
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
    type Error = LoginError;
    fn from_request(request: &'a Request<'r>)
        -> Outcome<AuthenticatedUser, LoginError> {
        let headers_result =
            read_auth_from_headers(&request.headers());
        headers_result.and_then(|(u, p)| {
            let conn_str = local_conn_string();
            let maybe_user =
                fetch_user_by_email(&conn_str, &String::from(u));
            let fail1 =
                (Status::NotFound, LoginError::UsernameDoesNotExist);
            option_to_outcome(maybe_user, fail1)
                .and_then(|user: UserEntity| {
                    let fail2 = (Status::MovedPermanently,
                        LoginError::WrongPassword);
                    option_to_outcome(
                        fetch_auth_info_by_user_id(
                            &conn_str, user.id), fail2)
                })
                .and_then(|auth_info: AuthInfoEntity| {
                    ...
                })
        })
    }
}
This gives us unwrapped authentication info. We can use this to compare the hash of the submitted password against the stored hash and return our final Outcome!
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
    type Error = LoginError;
    fn from_request(request: &'a Request<'r>)
        -> Outcome<AuthenticatedUser, LoginError> {
        let headers_result =
            read_auth_from_headers(&request.headers());
        headers_result.and_then(|(u, p)| {
            let conn_str = local_conn_string();
            let maybe_user =
                fetch_user_by_email(&conn_str, &String::from(u));
            let fail1 =
                (Status::NotFound, LoginError::UsernameDoesNotExist);
            option_to_outcome(maybe_user, fail1)
                .and_then(|user: UserEntity| {
                    let fail2 = (Status::MovedPermanently,
                        LoginError::WrongPassword);
                    option_to_outcome(
                        fetch_auth_info_by_user_id(
                            &conn_str, user.id), fail2)
                })
                .and_then(|auth_info: AuthInfoEntity| {
                    let hash = hash_password(&String::from(p));
                    if hash == auth_info.password_hash {
                        Outcome::Success( AuthenticatedUser{
                            user_id: auth_info.user_id})
                    } else {
                        Outcome::Failure(
                            (Status::Forbidden,
                             LoginError::WrongPassword))
                    }
                })
        })
    }
}
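For comparison, here is a sketch of the same lookup chain written with the `?` operator over a std `Result`. This design alternative keeps the happy path reading top to bottom without nested closures. It is not Rocket code. The in-memory `HashMap`s stand in for the database calls, `fake_hash` is a placeholder rather than a real password hash, and converting the final Result into Rocket's Outcome isn't shown. Every name here is hypothetical.

```rust
use std::collections::HashMap;

/// Hypothetical in-memory stand-ins for the database rows in the article.
pub struct UserEntity { pub id: i32 }
pub struct AuthInfoEntity { pub user_id: i32, pub password_hash: String }

#[derive(Debug, PartialEq)]
pub enum LoginError { UsernameDoesNotExist, WrongPassword }

/// Placeholder so the sketch runs without a crypto crate. NOT a real hash.
fn fake_hash(password: &str) -> String {
    password.chars().rev().collect()
}

/// Same steps as `from_request`: find the user, find their auth info,
/// compare hashes. Each `?` returns early with that step's error.
pub fn authenticate(
    users: &HashMap<String, UserEntity>,
    auth_infos: &HashMap<i32, AuthInfoEntity>,
    username: &str,
    password: &str,
) -> Result<i32, LoginError> {
    let user = users.get(username).ok_or(LoginError::UsernameDoesNotExist)?;
    let auth_info = auth_infos.get(&user.id).ok_or(LoginError::WrongPassword)?;
    if fake_hash(password) == auth_info.password_hash {
        Ok(auth_info.user_id)
    } else {
        Err(LoginError::WrongPassword)
    }
}
```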
Conclusion
Is this new solution that much better than our original? Well, it avoids the "triangle of death" pattern in our code. But it's not necessarily that much shorter. Perhaps it's a little cleaner on the whole, though. Ultimately these code choices are up to you! Next time, we'll wrap up our current exploration of Rust by seeing how to profile our code in Rust.
This series has covered some more advanced topics in Rust. For a more in-depth introduction, check out our Rust Video Tutorial!
Appendix: Original Function
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
    type Error = LoginError;
    fn from_request(request: &'a Request<'r>) -> Outcome<AuthenticatedUser, LoginError> {
        let username = request.headers().get_one("username");
        let password = request.headers().get_one("password");
        match (username, password) {
            (Some(u), Some(p)) => {
                let conn_str = local_conn_string();
                let maybe_user = fetch_user_by_email(&conn_str, &String::from(u));
                match maybe_user {
                    Some(user) => {
                        let maybe_auth_info = fetch_auth_info_by_user_id(&conn_str, user.id);
                        match maybe_auth_info {
                            Some(auth_info) => {
                                let hash = hash_password(&String::from(p));
                                if hash == auth_info.password_hash {
                                    Outcome::Success(AuthenticatedUser{user_id: auth_info.user_id})
                                } else {
                                    Outcome::Failure((Status::Forbidden, LoginError::WrongPassword))
                                }
                            }
                            None => {
                                Outcome::Failure((Status::MovedPermanently, LoginError::WrongPassword))
                            }
                        }
                    }
                    None => Outcome::Failure((Status::NotFound, LoginError::UsernameDoesNotExist))
                }
            },
            _ => Outcome::Failure((Status::BadRequest, LoginError::InvalidData))
        }
    }
}
Rocket Frontend: Templates and Static Assets
In the last few articles, we've been exploring the Rocket library for Rust web servers. Last time out, we tried a couple ways to add authentication to our web server. In this last Rocket-specific post, we'll explore some ideas around frontend templating. This will make it easy for you to serve HTML content to your users!
To explore the code for this article, head over to the "rocket_template" file on our Github repo! If you're still new to Rust, you might want to start with some simpler material. Take a look at our Rust Beginners Series as well!
Templating Basics
First, let's understand the basics of HTML templating. When our server serves out a webpage, we return HTML to the user for the browser to render. Consider this simple index page:
<html>
<head></head>
<body>
<p> Welcome to the site!</p>
</body>
</html>
But of course, each user should see some kind of custom content. For example, in our greeting, we might want to give the user's name. In an HTML template, we'll create a variable of sorts in our HTML, delineated by double braces:
<html>
<head></head>
<body>
<p> Welcome to the site {{name}}!</p>
</body>
</html>
Now before we return the HTML to the user, we want to perform a substitution. Where we find the variable {{name}}, we should replace it with the user's name, which our server should know.
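To make the substitution step concrete, here is a toy sketch of what a template engine does for a simple `{{name}}` variable. It replaces each `{{key}}` with its value from a context map. This is not how Tera or Handlebars are implemented. Real engines also handle HTML escaping, loops, conditionals, and partials, and this sketch does none of that. The function name `render_simple` is hypothetical.

```rust
use std::collections::HashMap;

/// Toy template substitution: replace every `{{key}}` with `context[key]`.
/// Unknown variables are left as-is; no HTML escaping is performed.
pub fn render_simple(template: &str, context: &HashMap<&str, &str>) -> String {
    let mut output = template.to_string();
    for (key, value) in context {
        // `{{{{` and `}}}}` are escaped braces, so this builds "{{key}}".
        output = output.replace(&format!("{{{{{}}}}}", key), value);
    }
    output
}
```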
There are many different libraries that do this, often through JavaScript. But in Rust, it turns out the Rocket library has a couple of easy templating integrations. One option is Tera, which was specifically designed for Rust. Another option is Handlebars, which is more native to JavaScript, but also has a Rocket integration. The substitutions in this article are simple, so there's not actually much of a difference for us.
Returning a Template
So how do we configure our server to return this HTML data? To start, we have to attach a "Fairing" to our server, specifically for the Template library. A Fairing is a server-wide piece of middleware. This is how we can allow our endpoints to return templates:
use rocket_contrib::templates::Template;
fn main() {
rocket::ignite()
.mount("/", routes![index, get_user])
.attach(Template::fairing())
.launch();
}
Now we can make our index endpoint. It has no inputs, and it will return Rocket's Template type.
#[get("/")]
fn index() -> Template {
...
}
We have two tasks now. First, we have to construct our context. This can be any "map-like" type with string information. We'll use a HashMap, populating the name value.
#[get("/")]
fn index() -> Template {
let context: HashMap<&str, &str> = [("name", "Jonathan")]
.iter().cloned().collect();
...
}
Now we have to render our template. Let's suppose we have a "templates" directory at the root of our project. We can put the template we wrote above in the "index.hbs" file. When we call the render function, we just give the name of our template and pass the context!
#[get("/")]
fn index() -> Template {
let context: HashMap<&str, &str> = [("name", "Jonathan")]
.iter().cloned().collect();
Template::render("index", &context)
}
Including Static Assets
Rocket also makes it quite easy to include static assets as part of our routing system. We just have to mount a StaticFiles handler (from rocket_contrib) at the desired prefix when launching our server:
fn main() {
rocket::ignite()
.mount("/static", StaticFiles::from("static"))
.mount("/", routes![index, get_user])
.attach(Template::fairing())
.launch();
}
Now any request to a /static/... endpoint will return the corresponding file in the "static" directory of our project. Suppose we have this styles.css file:
p {
color: red;
}
We can then link to this file in our index template:
<html>
<head>
<link rel="stylesheet" type="text/css" href="static/styles.css"/>
</head>
<body>
<p> Welcome to the site {{name}}!</p>
</body>
</html>
Now when we fetch our index, we'll see that the text on the page is red!
Looping in our Database
Now for one last piece of integration with our database. Let's make a page that will show a user their basic information. This starts with a simple template:
<!-- templates/user.hbs -->
<html>
<head></head>
<body>
<p> User name: {{name}}</p>
<br>
<p> User email: {{email}}</p>
<br>
<p> User age: {{age}}</p>
</body>
</html>
We'll compose an endpoint that takes the user's ID as an input and fetches the user from the database:
#[get("/users/<uid>")]
fn get_user(uid: i32) -> Template {
let maybe_user = fetch_user_by_id(&local_conn_string(), uid);
...
}
Now we need to build our context from the user information. This will require a match statement on the resulting user. We'll use "Unknown" for the fields if the user doesn't exist.
#[get("/users/<uid>")]
fn get_user(uid: i32) -> Template {
    let maybe_user = fetch_user_by_id(&local_conn_string(), uid);
    let context: HashMap<&str, String> = {
        match maybe_user {
            Some(u) =>
                [ ("name", u.name.clone())
                , ("email", u.email.clone())
                , ("age", u.age.to_string())
                ].iter().cloned().collect(),
            None =>
                [ ("name", String::from("Unknown"))
                , ("email", String::from("Unknown"))
                , ("age", String::from("Unknown"))
                ].iter().cloned().collect()
        }
    };
    Template::render("user", &context)
}
And to wrap it up, we'll render the "user" template! Now when users get directed to the page for their user ID, they'll see their information!
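The match in get_user handles the present and missing cases separately. Below is a sketch of an alternative that picks the three field values with `Option::map_or_else` and then fills the context map once, which echoes the map ideas from the Functor discussion. `UserEntity` here is a simplified, hypothetical stand-in for the article's Diesel entity, and `user_context` is a hypothetical name.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the article's `UserEntity`.
pub struct UserEntity { pub name: String, pub email: String, pub age: i32 }

/// Builds the template context from an optional user. The first closure is
/// the default (no user); the second formats a found user's fields.
pub fn user_context(maybe_user: Option<&UserEntity>) -> HashMap<&'static str, String> {
    let (name, email, age) = maybe_user.map_or_else(
        || ("Unknown".to_string(), "Unknown".to_string(), "Unknown".to_string()),
        |u| (u.name.clone(), u.email.clone(), u.age.to_string()),
    );
    let mut context = HashMap::new();
    context.insert("name", name);
    context.insert("email", email);
    context.insert("age", age);
    context
}
```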
Conclusion
Next week, we'll go back to some of our authentication code. But we'll do so with the goal of exploring a more universal Rust idea. We'll see how functors and monads still find a home in Rust. We'll explore the functions that allow us to clean up heavy conditional code just as we could in Haskell.
For a more in-depth introduction to Rust basics, be sure to take a look at our Rust Video Tutorial!
Authentication in Rocket
Last week we enhanced our Rocket web server. We combined our server with our Diesel schema to enable a series of basic CRUD endpoints. This week, we'll continue this integration, but bring in some more cool Rocket features. We'll explore two different methods of authentication. First, we'll create a "Request Guard" to allow a form of Basic Authentication. Then we'll also explore Rocket's amazingly simple Cookies integration.
As always, you can explore the code for this series by heading to our Github repository. For this article specifically, you'll want to take a look at the rocket_auth.rs file.
If you're just starting your Rust journey, feel free to check out our Beginners Series as well!
New Data Types
To start off, let's make a few new types to help us. First, we'll need a new database table, auth_infos, based on this struct:
#[derive(Insertable)]
pub struct AuthInfo {
pub user_id: i32,
pub password_hash: String
}
When the user creates their account, they'll provide a password. We'll store a hash of that password in our database table. Of course, you'll want to run through all the normal steps we did with Diesel to create this table. This includes having the corresponding Entity type.
We'll also want a couple new form types to accept authentication information. First off, when we create a user, we'll now include the password in the form.
#[derive(FromForm, Deserialize)]
struct CreateInfo {
name: String,
email: String,
age: i32,
password: String
}
Second, when a user wants to log in, they'll pass their username (email) and their password.
#[derive(FromForm, Deserialize)]
struct LoginInfo {
username: String,
password: String,
}
Both these types should derive FromForm and Deserialize so we can grab them out of "post" data. You might wonder, do we need another type to store the same information that already exists in User and UserEntity? It would be possible to write CreateInfo to have a User within it. But then we'd have to manually write the FromForm instance. This isn't difficult, but it might be more tedious than using a new type.
Creating a User
So in the first place, we have to create our user so they're matched up with their password. This requires taking the CreateInfo in our post request. We'll first unwrap the user fields and insert our User object. This follows the patterns we've seen so far in this series with Diesel.
#[post("/users/create", format="json", data="<create_info>")]
fn create(db: State<String>, create_info: Json<CreateInfo>)
-> Json<i32> {
let user: User = User
{ name: create_info.name.clone(),
email: create_info.email.clone(),
age: create_info.age};
let connection = ...;
let user_entity: UserEntity = diesel::insert_into(users::table)...
…
}
Now we'll want a function for hashing our password. We'll use the SHA3 algorithm, courtesy of the rust-crypto library:
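// From the rust-crypto crate: Digest provides input_str and result_str.
use crypto::digest::Digest;
use crypto::sha3::Sha3;

fn hash_password(password: &String) -> String {
    let mut hasher = Sha3::sha3_256();
    hasher.input_str(password);
    hasher.result_str()
}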
We'll apply this function on the input password and attach it to the created user ID. Then we can insert the new AuthInfo and return the created ID.
#[post("/users/create", format="json", data="<create_info>")]
fn create(db: State<String>, create_info: Json<CreateInfo>)
-> Json<i32> {
...
let user_entity: UserEntity = diesel::insert_into(users::table)...
let password_hash = hash_password(&create_info.password);
let auth_info: AuthInfo = AuthInfo
{user_id: user_entity.id, password_hash: password_hash};
let auth_info_entity: AuthInfoEntity =
diesel::insert_into(auth_infos::table)...
Json(user_entity.id)
}
Now whenever we create our user, they'll have their password attached!
Gating an Endpoint
Now that our user has a password, how do we gate endpoints on authentication? Well the first approach we can try is something like "Basic Authentication". This means that every authenticated request contains the username and the password. In our example we'll get these directly out of header elements. But in a real application you would want to double check that the request is encrypted before doing this.
But it would be tiresome to repeat the logic of reading the headers in every handler. So Rocket has a powerful feature called "Request Guards", built on a special trait called FromRequest. Whenever a handler function takes an input whose type implements FromRequest, Rocket runs that type's from_request function. This determines how to derive the value from the request. In our case, we'll make a wrapper type AuthenticatedUser. This represents a user that has included their auth info in the request.
struct AuthenticatedUser {
user_id: i32
}
Now we can include this type in a handler signature. For this endpoint, we only allow a user to retrieve their data if they've logged in:
#[get("/users/my_data")]
fn login(db: State<String>, user: AuthenticatedUser)
-> Json<Option<UserEntity>> {
Json(fetch_user_by_id(&db, user.user_id))
}
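To make the request-guard idea concrete without pulling in Rocket, here is a minimal std-only model of the mechanism. Everything in it is hypothetical: Request, Outcome, the FromRequest trait, and dispatch are simplified stand-ins we define ourselves (Rocket's real trait carries lifetimes, and its Outcome also has a Forward case), and the hard-coded "james"/"secret" credential exists purely for illustration. The point is the control flow: the "framework" builds the guard from the request first, and only calls the handler if that succeeds.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for an HTTP request: just its headers.
struct Request {
    headers: HashMap<String, String>,
}

// Hypothetical, simplified outcome (Rocket's real one also has `Forward`).
#[derive(Debug)]
enum Outcome<T, E> {
    Success(T),
    Failure((u16, E)),
}

// Hypothetical trait loosely modelled on Rocket's `FromRequest` (no lifetimes).
trait FromRequest: Sized {
    type Error;
    fn from_request(request: &Request) -> Outcome<Self, Self::Error>;
}

#[derive(Debug)]
struct AuthenticatedUser {
    user_id: i32,
}

#[derive(Debug)]
enum LoginError {
    InvalidData,
    WrongPassword,
}

impl FromRequest for AuthenticatedUser {
    type Error = LoginError;

    fn from_request(request: &Request) -> Outcome<Self, LoginError> {
        let username = request.headers.get("username");
        let password = request.headers.get("password");
        match (username, password) {
            // Hard-coded demo credential, for illustration only.
            (Some(u), Some(p)) if u == "james" && p == "secret" => {
                Outcome::Success(AuthenticatedUser { user_id: 1 })
            }
            (Some(_), Some(_)) => Outcome::Failure((403, LoginError::WrongPassword)),
            _ => Outcome::Failure((400, LoginError::InvalidData)),
        }
    }
}

// The "framework" side: build the guard first, call the handler only on success.
fn dispatch<G, R, F>(request: &Request, handler: F) -> Result<R, u16>
where
    G: FromRequest,
    F: Fn(G) -> R,
{
    match G::from_request(request) {
        Outcome::Success(guard) => Ok(handler(guard)),
        Outcome::Failure((status, _error)) => Err(status),
    }
}

// A handler that can only run with an authenticated user in hand.
fn my_data(user: AuthenticatedUser) -> String {
    format!("data for user {}", user.user_id)
}
```

Because my_data demands an AuthenticatedUser, there is no way to call it without the guard having succeeded first; the type signature does the gating.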
Implementing the FromRequest Trait
The trick, of course, is that we need to implement the FromRequest trait! This is more complicated than it sounds! Our implementation has the ability to short-circuit the request and return an error. So let's start by specifying a couple potential login errors we can throw.
#[derive(Debug)]
enum LoginError {
InvalidData,
UsernameDoesNotExist,
WrongPassword
}
The from_request function will take in a request and return an Outcome. The outcome will either provide our authentication type or an error. The last bit of adornment we need on this is lifetime specifiers for the request itself and the reference to it.
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
type Error = LoginError;
fn from_request(request: &'a Request<'r>)
-> Outcome<AuthenticatedUser, LoginError> {
...
}
}
Now the actual function definition involves several layers of case matching! It consists of a few different operations that have to query the request or query our database. For example, let's consider the first layer. We insist on having two headers in our request: one for the username, and one for the password. We'll use request.headers() to check for these values. If either doesn't exist, we'll send a Failure outcome with invalid data. Here's what that looks like:
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
type Error = LoginError;
fn from_request(request: &'a Request<'r>)
-> Outcome<AuthenticatedUser, LoginError> {
let username = request.headers().get_one("username");
let password = request.headers().get_one("password");
match (username, password) {
(Some(u), Some(p)) => {
...
}
_ => Outcome::Failure(
(Status::BadRequest,
LoginError::InvalidData))
}
}
}
In the main branch of the function, we'll do 3 steps:
- Find the user in our database based on their email address/username.
- Find their authentication information based on the ID.
- Hash the input password and compare it to the database hash.
If we are successful, then we'll return a successful outcome:
Outcome::Success(AuthenticatedUser{user_id: user.id})
The number of match levels required makes the function definition very verbose, so we've included it at the bottom as an appendix. We know how to take such a function and write it more cleanly in Haskell using monads. In a couple weeks, we'll use this function as a case study to explore Rust's monadic abilities.
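As a preview of that idea, here is one way to flatten the nested matches, sketched under some loud assumptions. The in-memory Db type, its two lookup methods, and the placeholder hash_password (which just prefixes the string and is NOT a real hash) are hypothetical stand-ins for the Diesel helpers. The technique itself is ordinary Rust: ok_or turns each Option into a Result carrying the right LoginError, and ? returns early on the first failure.

```rust
#[derive(Debug, PartialEq)]
enum LoginError {
    InvalidData,
    UsernameDoesNotExist,
    WrongPassword,
}

struct UserEntity {
    id: i32,
    email: String,
}

struct AuthInfo {
    user_id: i32,
    password_hash: String,
}

// Hypothetical in-memory stand-in for the database and its lookup helpers.
struct Db {
    users: Vec<UserEntity>,
    auth_infos: Vec<AuthInfo>,
}

impl Db {
    fn fetch_user_by_email(&self, email: &str) -> Option<&UserEntity> {
        self.users.iter().find(|u| u.email == email)
    }

    fn fetch_auth_info_by_user_id(&self, user_id: i32) -> Option<&AuthInfo> {
        self.auth_infos.iter().find(|a| a.user_id == user_id)
    }
}

// Placeholder only, NOT a real password hash.
fn hash_password(password: &str) -> String {
    format!("hashed:{}", password)
}

// Each `ok_or(...)?` replaces one level of `match`: on `None` we return that error early.
fn authenticate(db: &Db, username: Option<&str>, password: Option<&str>) -> Result<i32, LoginError> {
    let u = username.ok_or(LoginError::InvalidData)?;
    let p = password.ok_or(LoginError::InvalidData)?;
    let user = db.fetch_user_by_email(u).ok_or(LoginError::UsernameDoesNotExist)?;
    let auth = db
        .fetch_auth_info_by_user_id(user.id)
        .ok_or(LoginError::WrongPassword)?;
    if hash_password(p) == auth.password_hash {
        Ok(user.id)
    } else {
        Err(LoginError::WrongPassword)
    }
}
```

Inside a real guard, from_request could call a helper shaped like this and map the Err case to an Outcome::Failure with an appropriate Status.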
Logging In with Cookies
In most applications, though, we won't want to include the password in the request each time. In HTTP, "Cookies" let the server store small pieces of information about a particular user in their browser, which the browser then sends back with each request.
Rocket makes this very easy with the Cookies type! We can include this type as a mutable input to any of our handlers. It works like a key-value store, where we can access certain information with a key like "user_id". Since we're storing auth information, we'll also want to make sure it's encrypted, or "private". So we'll use these functions:
add_private(...)
get_private(...)
remove_private(...)
Let's start with a "login" endpoint. This will take our LoginInfo object as its post data, but we'll also have the Cookies input:
#[post("/users/login", format="json", data="<login_info>")]
fn login_post(db: State<String>, login_info: Json<LoginInfo>, mut cookies: Cookies) -> Json<Option<i32>> {
...
}
First we have to make sure a user of that name exists in the database:
#[post("/users/login", format="json", data="<login_info>")]
fn login_post(
db: State<String>,
login_info: Json<LoginInfo>,
mut cookies: Cookies)
-> Json<Option<i32>> {
let maybe_user = fetch_user_by_email(&db, &login_info.username);
match maybe_user {
Some(user) => {
...
}
None => Json(None)
}
}
Then we have to get their auth info again. We'll hash the password and compare it. If we're successful, then we'll add the user's ID as a cookie. If not, we'll return None.
#[post("/users/login", format="json", data="<login_info>")]
fn login_post(
db: State<String>,
login_info: Json<LoginInfo>,
mut cookies: Cookies)
-> Json<Option<i32>> {
let maybe_user = fetch_user_by_email(&db, &login_info.username);
match maybe_user {
Some(user) => {
let maybe_auth = fetch_auth_info_by_user_id(&db, user.id);
match maybe_auth {
Some(auth_info) => {
let hash = hash_password(&login_info.password);
if hash == auth_info.password_hash {
cookies.add_private(Cookie::new(
"user_id", user.id.to_string()));
Json(Some(user.id))
} else {
Json(None)
}
}
None => Json(None)
}
}
None => Json(None)
}
}
A more robust solution would, of course, report specific errors instead of returning None.
Using Cookies
Using our cookie now is pretty easy. Let's make a separate "fetch user" endpoint using our cookies. It will take the Cookies object and the user ID as inputs. The first order of business is to retrieve the user_id cookie and verify it exists.
#[get("/users/cookies/<uid>")]
fn fetch_special(db: State<String>, uid: i32, mut cookies: Cookies)
-> Json<Option<UserEntity>> {
let logged_in_user = cookies.get_private("user_id");
match logged_in_user {
Some(c) => {
...
},
None => Json(None)
}
}
Now we need to parse the string value as a user ID and compare it to the value from the endpoint. If they're a match, we just fetch the user's information from our database!
#[get("/users/cookies/<uid>")]
fn fetch_special(db: State<String>, uid: i32, mut cookies: Cookies)
-> Json<Option<UserEntity>> {
let logged_in_user = cookies.get_private("user_id");
match logged_in_user {
Some(c) => {
let logged_in_uid = c.value().parse::<i32>().unwrap();
if logged_in_uid == uid {
Json(fetch_user_by_id(&db, uid))
} else {
Json(None)
}
},
None => Json(None)
}
}
And when we're done, we can also post a "logout" request that will remove the cookie!
#[post("/users/logout", format="json")]
fn logout(mut cookies: Cookies) -> () {
cookies.remove_private(Cookie::named("user_id"));
}
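The whole cookie lifecycle (log in, check the cookie, log out) can be modeled with a plain map. This is a sketch only: CookieJar and its methods are hypothetical stand-ins for Rocket's Cookies, and unlike add_private they do no encryption at all. One deliberate difference from the handler above: the check uses parse(...).ok() rather than unwrap(), so a malformed cookie value is treated as "not logged in" instead of panicking.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for Rocket's `Cookies`: a plain map with no encryption.
#[derive(Default)]
struct CookieJar {
    values: HashMap<String, String>,
}

impl CookieJar {
    fn add(&mut self, name: &str, value: String) {
        self.values.insert(name.to_string(), value);
    }

    fn get(&self, name: &str) -> Option<&str> {
        self.values.get(name).map(|v| v.as_str())
    }

    fn remove(&mut self, name: &str) {
        self.values.remove(name);
    }
}

// Mirrors login_post's success branch: remember who logged in.
fn login(jar: &mut CookieJar, user_id: i32) {
    jar.add("user_id", user_id.to_string());
}

// Mirrors fetch_special's check, without unwrap: a bad value just means "no".
fn is_logged_in_as(jar: &CookieJar, uid: i32) -> bool {
    jar.get("user_id")
        .and_then(|v| v.parse::<i32>().ok())
        .map_or(false, |logged_in_uid| logged_in_uid == uid)
}

// Mirrors the logout endpoint.
fn logout(jar: &mut CookieJar) {
    jar.remove("user_id");
}
```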
Conclusion
We've got one more article on Rocket before checking out some different Rust concepts. So far, we've only dealt with the backend part of our API. Next week, we'll investigate how we can use Rocket to send templated HTML files and other static web content!
Maybe you're more experienced with Haskell but still need a bit of an introduction to Rust. We've got some other materials for you! Watch our Rust Video Tutorial for an in-depth look at the basics of the language!
Appendix: From Request Function
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
type Error = LoginError;
fn from_request(request: &'a Request<'r>) -> Outcome<AuthenticatedUser, LoginError> {
let username = request.headers().get_one("username");
let password = request.headers().get_one("password");
match (username, password) {
(Some(u), Some(p)) => {
let conn_str = local_conn_string();
let maybe_user = fetch_user_by_email(&conn_str, &String::from(u));
match maybe_user {
Some(user) => {
let maybe_auth_info = fetch_auth_info_by_user_id(&conn_str, user.id);
match maybe_auth_info {
Some(auth_info) => {
let hash = hash_password(&String::from(p));
if hash == auth_info.password_hash {
Outcome::Success(AuthenticatedUser{user_id: user.id})
} else {
Outcome::Failure((Status::Forbidden, LoginError::WrongPassword))
}
}
None => {
Outcome::Failure((Status::Forbidden, LoginError::WrongPassword))
}
}
}
None => Outcome::Failure((Status::NotFound, LoginError::UsernameDoesNotExist))
}
},
_ => Outcome::Failure((Status::BadRequest, LoginError::InvalidData))
}
}
}
Joining Forces: An Integrated Rust Web Server
We've now explored a couple different libraries for some production tasks in Rust. A couple weeks ago, we used Diesel to create an ORM for some database types. And then last week, we used Rocket to make a basic web server that responds to simple requests. This week, we'll put these two ideas together! We'll use some more advanced functionality from Rocket to make some CRUD endpoints for our database type. Take a look at the code on Github here!
If you've never written any Rust, you should start with the basics though! Take a look at our Rust Beginners Series!
Database State and Instances
Our first order of business is connecting to the database from our handler functions. There are some direct integrations you can check out between Rocket, Diesel, and other libraries. These can provide clever ways to add a connection argument to any handler.
But for now we're going to keep things simple. We'll re-generate the PgConnection within each endpoint. We'll maintain a "stateful" connection string to ensure they all use the same database.
Our Rocket server can "manage" different state elements. Suppose we have a function that gives us our database string. We can pass that to our server at initialization time.
fn local_conn_string() -> String {...}
fn main() {
rocket::ignite()
.mount("/", routes![...])
.manage(local_conn_string())
.launch();
}
Now we can access this String from any of our endpoints by giving an input the State<String> type. This allows us to create our connection:
#[get(...)]
fn fetch_all_users(database_url: State<String>) -> ... {
let connection = PgConnection::establish(&database_url)
.expect("Error connecting to database!");
...
}
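Note: We can't manage the PgConnection itself, because Rocket requires managed state to be thread safe (Send + Sync), and a single connection can't be shared safely between handler threads.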
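Here is a small std-only illustration of why thread safety matters for shared state. The run_workers helper is hypothetical, but its Send + Sync + 'static bound is the same kind of bound Rocket places on .manage(...). A String satisfies it, so every worker thread can read the same connection string through an Arc; a type holding an Rc or a RefCell would be rejected at compile time.

```rust
use std::fmt::Display;
use std::sync::Arc;
use std::thread;

// Hypothetical helper. Its `Send + Sync + 'static` bound is the same kind of
// bound Rocket puts on managed state: the value must be safe to share across threads.
fn run_workers<T>(state: T, workers: usize) -> Vec<String>
where
    T: Display + Send + Sync + 'static,
{
    // One shared, reference-counted copy of the state.
    let shared = Arc::new(state);
    let handles: Vec<_> = (0..workers)
        .map(|i| {
            let state = Arc::clone(&shared);
            // Each "handler thread" reads the same shared value.
            thread::spawn(move || format!("worker {} connects to {}", i, state))
        })
        .collect();
    // Join in spawn order, so result i comes from worker i.
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```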
So any of our endpoints can now access the same database. Before we start writing these, though, we need a couple things first. Let's recall that for our Diesel ORM we made a User type and a UserEntity type. The first is for inserting/creating, and the second is for querying. We need to add some instances to those types so they are compatible with our endpoints. We want JSON instances (Serialize, Deserialize), as well as FromForm for our User type:
#[derive(Insertable, Deserialize, Serialize, FromForm)]
#[table_name="users"]
pub struct User {
...
}
#[derive(Queryable, Serialize)]
pub struct UserEntity {
...
}
Now let's see how we get these types from our endpoints!
Retrieving Users
We'll start with a simple endpoint to fetch all the different users in our database. This will take no inputs, except our stateful database URL. It will return a vector of UserEntity objects, wrapped in Json.
#[get("/users/all")]
fn fetch_all_users(database_url: State<String>)
-> Json<Vec<UserEntity>> {
...
}
Now all we need to do is connect to our database and run the query function. We can make our users vector into a Json object by wrapping it with Json(). The Serialize instance lets us satisfy the Responder trait for the return value.
#[get("/users/all")]
fn fetch_all_users(database_url: State<String>)
-> Json<Vec<UserEntity>> {
let connection = PgConnection::establish(&database_url)
.expect("Error connecting to database!");
use rust_web::schema::users::dsl::*;
Json(users.load::<UserEntity>(&connection)
.expect("Error loading users"))
}
Now for getting individual users. Once again, we'll wrap a response in JSON. But this time we'll return an optional single user. We'll use a dynamic capture parameter in the URL for the user ID.
#[get("/users/<uid>")]
fn fetch_user(database_url: State<String>, uid: i32)
-> Option<Json<UserEntity>> {
let connection = ...;
...
}
We'll want to filter on the users table by the ID. This gives us a vector of results, which we'll declare as mutable. Why? In the end, we want to return the first user. But Rust won't let us move a single element out of a vector just by indexing it, and we don't want to copy it either. So we'll remove the head from the vector entirely, which requires mutability.
#[get("/users/<uid>")]
fn fetch_user(database_url: State<String>, uid: i32)
-> Option<Json<UserEntity>> {
let connection = ...;
use rust_web::schema::users::dsl::*;
let mut users_by_id: Vec<UserEntity> =
users.filter(id.eq(uid))
.load::<UserEntity>(&connection)
.expect("Error loading users");
...
}
Now we can do our case analysis. If the list is empty, we return None. Otherwise, we'll remove the user from the vector and wrap it.
#[get("/users/<uid>")]
fn fetch_user(database_url: State<String>, uid: i32) -> Option<Json<UserEntity>> {
let connection = ...;
use rust_web::schema::users::dsl::*;
let mut users_by_id: Vec<UserEntity> =
users.filter(id.eq(uid))
.load::<UserEntity>(&connection)
.expect("Error loading users");
if users_by_id.len() == 0 {
None
} else {
let first_user = users_by_id.remove(0);
Some(Json(first_user))
}
}
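If we'd rather not declare the vector as mutable, an alternative is to consume it with into_iter and take the first element. The sketch below uses a hypothetical, trimmed-down UserEntity with only two fields; the function behaves like the len()/remove(0) version, returning None for an empty vector.

```rust
#[derive(Debug, PartialEq)]
struct UserEntity {
    id: i32,
    name: String,
}

// Takes ownership of the whole vector and yields its first element, if any.
// No `mut` is needed because we never modify the vector in place.
fn first_user(users_by_id: Vec<UserEntity>) -> Option<UserEntity> {
    users_by_id.into_iter().next()
}
```

Inside the handler, the whole if/else could then become users_by_id.into_iter().next().map(Json). Diesel can also do this in the query itself: users.find(uid).first::<UserEntity>(&connection).optional() should give a QueryResult<Option<UserEntity>>, with optional() coming from Diesel's OptionalExtension trait.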
Create/Update/Delete
Hopefully you can see the pattern now! Our queries are all pretty simple, so our endpoints all follow the same steps: connect to the database, run the query, and wrap the result. We can follow this process for the remaining three endpoints in a basic CRUD setup. Let's start with "Create":
#[post("/users/create", format="application/json", data = "<user>")]
fn create_user(database_url: State<String>, user: Json<User>)
-> Json<i32> {
let connection = ...;
let user_entity: UserEntity = diesel::insert_into(users::table)
.values(&*user)
.get_result(&connection).expect("Error saving user");
Json(user_entity.id)
}
As we discussed last week, we can use data together with Json to specify the body of our post request. We dereference the user with * to get the User out of the Json wrapper, and then borrow it with &. Then we insert the user and wrap its ID to send back.
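To see what &*user is doing, here is a std-only model. The Json struct below is a hypothetical newtype, not Rocket's type, but Rocket's Json<T> likewise implements Deref to T, which is what makes both &*user and field access like user.name work. The explicit * matters when the receiving function is generic (as Diesel's values is), because deref coercion doesn't kick in for generic parameters; with a concrete &User parameter, plain &user would coerce too.

```rust
use std::ops::Deref;

// Hypothetical newtype standing in for Rocket's `Json<T>` wrapper.
struct Json<T>(T);

impl<T> Deref for Json<T> {
    type Target = T;
    fn deref(&self) -> &T {
        &self.0
    }
}

struct User {
    name: String,
    age: i32,
}

// Stand-in for a function that wants a plain `&User`.
fn describe(user: &User) -> String {
    format!("{} ({})", user.name, user.age)
}
```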
Deleting a user is simple as well. It has the same dynamic path as fetching a user. We just make a delete call on our database instead.
#[delete("/users/<uid>")]
fn delete_user(database_url: State<String>, uid: i32) -> Json<i32> {
let connection = ...;
use rust_web::schema::users::dsl::*;
diesel::delete(users.filter(id.eq(uid)))
.execute(&connection)
.expect("Error deleting user");
Json(uid)
}
Updating is the last endpoint, which takes a put request. The endpoint mechanics are just like our other endpoints. We use a dynamic path component to get the user's ID, and then provide a User body with the updated field values. The only trick is that we need to expand our Diesel knowledge a bit. We'll use update and set to change individual fields on an item.
#[put("/users/<uid>/update", format="json", data="<user>")]
fn update_user(
database_url: State<String>, uid: i32, user: Json<User>)
-> Json<UserEntity> {
let connection = ...;
use rust_web::schema::users::dsl::*;
let updated_user: UserEntity =
diesel::update(users.filter(id.eq(uid)))
.set((name.eq(&user.name),
email.eq(&user.email),
age.eq(user.age)))
.get_result::<UserEntity>(&connection)
.expect("Error updating user");
Json(updated_user)
}
The other gotcha is that we need to use references (&) for the string fields in the input user. But now we can add these routes to our server, and it will manipulate our database as desired!
Conclusion
There are still lots of things we could improve here. For example, we're still using .expect in many places. From the perspective of a web server, we should be catching these issues and returning them as "500 Internal Server Error" responses. Rocket also provides some good mechanics for fixing that. Next week though, we'll pivot to another server problem that Rocket solves adeptly: authentication. We should restrict certain endpoints to particular users. Rocket provides an authentication scheme that is neatly encoded in the type system!
For a more in-depth introduction to Rust, watch our Rust Video Tutorial. It will take you through a lot of key skills like understanding memory and using Cargo!
Rocket: Web Servers in Rust!
Welcome back to our series on building simple apps in Rust. Last week, we explored the Diesel library which gave us an ORM for database interaction. For the next few weeks, we'll be trying out the Rocket library, which makes it quick and easy to build a web server in Rust! This is comparable to the Servant library in Haskell, which we've explored before.
This week, we'll be working on the basic building blocks of using this library. The reference code for this article is available here on Github!
Rust combines some of the neat functional ideas of Haskell with some more recognizable syntax from C++. To learn more of the basics, take a look at our Rust Beginners Series!
Our First Route
To begin, let's make a simple "hello world" endpoint for our server. We don't specify a full API definition all at once like we do with Servant. But we do use a special macro before the endpoint function. This macro describes the route's method and its path.
#[get("/hello")]
fn index() -> String {
String::from("Hello, world!")
}
So our macro tells us this is a "GET" endpoint and that the path is /hello. Then our function specifies a String as the return value. We can, of course, have different types of return values, which we'll explore more as the series goes on.
Launching Our Server
Now this endpoint is useless until we can actually run and launch our server. To do this, we start by creating an object of type Rocket with the ignite() function.
fn main() {
let server: Rocket = rocket::ignite();
...
}
We can then modify our server by "mounting" the routes we want. The mount function takes a base URL path and a list of routes, as generated by the routes macro. This function returns us a modified server:
fn main() {
let server: Rocket = rocket::ignite();
let server2: Rocket = server.mount("/", routes![index]);
}
Rather than create multiple server objects, we'll just compose these different functions. Then to launch our server, we use launch on the final object!
fn main() {
rocket::ignite().mount("/", routes![index]).launch();
}
And now our server will respond when we ping it at localhost:8000/hello! We could, of course, use a different base path. We could even assign different routes to different bases!
fn main() {
rocket::ignite().mount("/api", routes![index]).launch();
}
Now it will respond at /api/hello.
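The ignite/mount/launch chain is the builder pattern: each call takes the server by value and hands back a modified one. Here is a hypothetical std-only builder showing just that shape, plus how a base path combines with a route path. Server, ignite, mount, and paths are our own stand-ins, not Rocket's API, and paths simply lists the URLs instead of launching anything.

```rust
// Hypothetical stand-in for a Rocket-style server builder; not Rocket's API.
struct Server {
    mounts: Vec<(String, Vec<String>)>,
}

fn ignite() -> Server {
    Server { mounts: Vec::new() }
}

impl Server {
    // Takes the server by value and returns the updated one, so calls can be chained.
    fn mount(mut self, base: &str, routes: &[&str]) -> Server {
        let routes: Vec<String> = routes.iter().map(|r| r.to_string()).collect();
        self.mounts.push((base.to_string(), routes));
        self
    }

    // Instead of launching, list the full paths each route would answer on.
    fn paths(&self) -> Vec<String> {
        self.mounts
            .iter()
            .flat_map(|(base, routes)| routes.iter().map(move |r| join_path(base, r)))
            .collect()
    }
}

// "/" + "/hello" -> "/hello", "/api" + "/hello" -> "/api/hello"
fn join_path(base: &str, route: &str) -> String {
    format!("{}{}", base.trim_end_matches('/'), route)
}
```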
Query Parameters
Naturally, most endpoints need inputs to be useful. There are a few different ways we can do this. The first is to use path components. In Servant, we use Capture for these. With Rocket, we'll format our URL to have angle brackets around the variables we want to capture. Then we can assign them a basic type in our endpoint function:
#[get("/math/<name>")]
fn hello(name: &RawStr) -> String {
format!("Hello, {}!", name.as_str())
}
We can use any type that satisfies the FromParam trait, including a RawStr. This is a Rocket-specific type wrapping string-like data in its raw form, before any decoding or sanitization. So with these strings, we might want to apply some sanitization processes to our data. We can also use basic numeric types, like i32.
#[get("/math/<first>/<second>")]
fn add(first: i32, second: i32) -> String {
format!("{}", first + second)
}
This endpoint will now return "11" when we ping /math/5/6.
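Roughly, matching /math/<first>/<second> means splitting the path into segments, comparing the fixed ones, and parsing the captured ones. The match_two_ints function below is a hypothetical, much-simplified matcher written for illustration; Rocket's real router and FromParam are more involved. When a capture fails to parse, it simply reports no match, which is loosely similar to how Rocket forwards a request when a parameter fails to parse.

```rust
// Hypothetical matcher for patterns like "/math/<first>/<second>".
// Fixed segments must match exactly; `<...>` segments are parsed as i32.
fn match_two_ints(pattern: &str, path: &str) -> Option<(i32, i32)> {
    let pat: Vec<&str> = pattern.trim_start_matches('/').split('/').collect();
    let segs: Vec<&str> = path.trim_start_matches('/').split('/').collect();
    if pat.len() != segs.len() {
        return None;
    }
    let mut captured = Vec::new();
    for (p, s) in pat.iter().zip(segs.iter()) {
        if p.starts_with('<') && p.ends_with('>') {
            // A capture that isn't a valid i32 means "this route doesn't match".
            captured.push(s.parse::<i32>().ok()?);
        } else if p != s {
            return None;
        }
    }
    match captured.as_slice() {
        [a, b] => Some((*a, *b)),
        _ => None,
    }
}
```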
We can also use "query parameters", which all go at the end of the URL. These need the FromFormValue trait, rather than FromParam. But once again, RawStr and basic numbers work fine.
#[get("/math?<first>&<second>")]
fn multiply(first: i32, second: i32) -> String {
format!("{}", first * second)
}
Now we'll get "30" when we ping /math?first=5&second=6.
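Query parameters arrive as name=value pairs joined by &, which is why the URL has to name first and second explicitly. The sketch below is a hypothetical, simplified parser (no percent-decoding, and later duplicate keys overwrite earlier ones) that shows how those pairs map onto the two numbers; it is not how Rocket implements FromFormValue.

```rust
use std::collections::HashMap;

// Hypothetical, simplified parser: "first=5&second=6" -> {"first": "5", "second": "6"}.
// Pairs without an '=' are skipped; no percent-decoding is done.
fn parse_query(query: &str) -> HashMap<&str, &str> {
    query
        .split('&')
        .filter_map(|pair| {
            let mut parts = pair.splitn(2, '=');
            Some((parts.next()?, parts.next()?))
        })
        .collect()
}

// Looks up both named parameters, parses them, and multiplies.
fn multiply(query: &str) -> Option<String> {
    let params = parse_query(query);
    let first: i32 = params.get("first")?.parse().ok()?;
    let second: i32 = params.get("second")?.parse().ok()?;
    Some((first * second).to_string())
}
```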
Post Requests
The last major input type we'll deal with is post request data. Suppose we have a basic user type:
struct User {
name: String,
email: String,
age: i32
}
We'll want to derive various traits for it so we can use it within endpoints. From the Rust "Serde" library we'll want Deserialize and Serialize so we can convert it to and from JSON. Then we'll also want FromForm so that it could also be accepted as form-encoded post data.
#[derive(FromForm, Deserialize, Serialize)]
struct User {
...
}
Now we can make our endpoint, but we'll have to specify the "format" as JSON and the "data" as using our "user" type.
#[post("/users/create", format="json", data="<user>")]
fn create_user(user: Json<User>) -> String {
...
}
We need to provide the Json wrapper for our input type, but we can use it as though it's a normal User. For now, we'll just return a string echoing the user's information back to us. Don't forget to add each new endpoint to the routes macro in your server definition!
#[post("/users/create", format="json", data="<user>")]
fn create_user(user: Json<User>) -> String {
format!(
"Created user: {} {} {}", user.name, user.email, user.age)
}
Conclusion
Next time, we'll explore making a more systematic CRUD server. We'll add database integration and see some other tricks for serializing data and maintaining state. Then we'll explore more advanced topics like authentication, static files, and templating!
If you're going to be building a web application in Rust, you'd better have a solid foundation! Watch our Rust Video Tutorial to get an in-depth introduction!
Diesel: A Rust-y ORM
Last week on Monday Morning Haskell we took our first step into some real world tasks with Rust. We explored the simple Rust Postgres library to connect to a database and run some queries. This week we're going to use Diesel, a library with some cool ORM capabilities. It's a bit like the Haskell library Persistent, which you can explore more in our Real World Haskell Series.
For a more in-depth look at the code for this article, you should take a look at our Github Repository! You'll want to look at the files referenced below and also at the executable here.
Diesel CLI
Our first step is to add Diesel as a dependency in our program. We briefly discussed Cargo "features" in last week's article. Diesel has separate features for each backend you might use. So we'll specify "postgres". Once again, we'll also use a special feature for the chrono library so we can use timestamps in our database.
[dependencies]
diesel={version="1.4.4", features=["postgres", "chrono"]}
But there's more! Diesel comes with a CLI that helps us manage our database migrations. It also will generate some of our schema code. Just as we can install binaries with Stack using stack install, we can do the same with Cargo. We only want to enable the features we need. Otherwise the installation will fail if we don't have the client libraries for the other databases (MySQL and SQLite) installed.
>> cargo install diesel_cli --no-default-features --features postgres
Now we can use the program to set up our project and generate our migrations. We begin with this command.
>> diesel setup
This creates a couple different items in our project directory. First, we have a "migrations" folder, where we'll put some SQL code. Then we also get a schema.rs file in our src directory. Diesel will automatically generate code for us in this file. Let's see how!
Migrations and Schemas
When using Persistent in Haskell, we defined our basic types in a single Schema file using a special template language. We could run migrations on our whole schema programmatically, without our own SQL. But it is difficult to track more complex migrations as your schema evolves.
Diesel is a bit different. Unfortunately, we have to write our own SQL. But, we'll do so in a way that it's easy to take more granular actions on our table. Diesel will then generate a schema file for us. But we'll still need some extra work to get the Rust types we'll need. To start though, let's use Diesel to generate our first migration. This migration will create our "users" table.
>> diesel migration generate create_users
This creates a new folder within our "migrations" directory for this "create_users" migration. It has two files, up.sql and down.sql. We start by populating the up.sql file to specify the SQL we need to run the migration.
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
email TEXT NOT NULL,
age INTEGER NOT NULL
)
Then we also want the down.sql file to contain SQL that reverses the migration.
DROP TABLE users CASCADE;
Once we've written these, we can run our migration!
>> diesel migration run
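We can then undo the migration, running the code in down.sql, with this command:
>> diesel migration revert
Diesel also has diesel migration redo, which reverts the latest migration and then runs it again. This is a handy way to check that our down.sql really undoes our up.sql.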
The result of running our migration is that Diesel populates the schema.rs file. This file uses the table macro that generates helpful types and trait instances for us. We'll use this a bit when incorporating the table into our code.
table! {
users (id) {
id -> Int4,
name -> Text,
email -> Text,
age -> Int4,
}
}
While we're at it, let's make one more migration to add an articles table.
-- migrations/create_articles/up.sql
CREATE TABLE articles (
id SERIAL PRIMARY KEY,
title TEXT NOT NULL,
body TEXT NOT NULL,
published_at TIMESTAMP WITH TIME ZONE NOT NULL,
author_id INTEGER REFERENCES users(id) NOT NULL
)
-- migrations/create_articles/down.sql
DROP TABLE articles;
Then we can once again use diesel migration run.
Model Types
Now, while Diesel will generate a lot of useful code for us, we still need to do some work on our own. We have to create our own structs for the data types to take advantage of the instances we get. With Persistent, we got these for free. Persistent also used a wrapper Entity type, which attached a Key to our actual data.
Diesel doesn't have the notion of an entity. We have to manually make two different types, one with the database key and one without. For the "Entity" type which has the key, we'll derive the "Queryable" trait. Then we can use Diesel's functions to select items from the table.
#[derive(Queryable)]
pub struct UserEntity {
pub id: i32,
pub name: String,
pub email: String,
pub age: i32
}
We then have to declare a separate type that implements "Insertable". This doesn't have the database key, since we don't know the key before inserting the item. This should be a copy of our entity type, but without the key field. We use a second attribute, table_name, to tie it to the users table.
#[derive(Insertable)]
#[table_name="users"]
pub struct User {
pub name: String,
pub email: String,
pub age: i32
}
Note that for our foreign key field (author_id on articles), we'll use a normal integer for our column reference. In Persistent we would have a special Key type. We lose some of the semantic meaning of this field by doing this. But it can help keep more of our code separate from this specific library.
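The split between the two types mirrors how the database assigns keys. As an in-memory sketch (UsersTable is hypothetical and only imitates a SERIAL column), inserting a User without an id produces a UserEntity that has one. This is the same round trip Diesel performs when we insert a User and get a UserEntity back.

```rust
// Insert-side type: no id, because the database assigns it.
#[derive(Debug, Clone, PartialEq)]
struct User {
    name: String,
    email: String,
    age: i32,
}

// Query-side type: the same fields plus the database key.
#[derive(Debug, Clone, PartialEq)]
struct UserEntity {
    id: i32,
    name: String,
    email: String,
    age: i32,
}

// Hypothetical in-memory table that hands out ids the way a SERIAL column would.
#[derive(Default)]
struct UsersTable {
    rows: Vec<UserEntity>,
    next_id: i32,
}

impl UsersTable {
    // Assigns the next id (starting at 1), stores the row, and returns the entity.
    fn insert(&mut self, user: User) -> UserEntity {
        self.next_id += 1;
        let entity = UserEntity {
            id: self.next_id,
            name: user.name,
            email: user.email,
            age: user.age,
        };
        self.rows.push(entity.clone());
        entity
    }
}
```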
Making Queries
Now that we have our models in place, we can start using them to write queries. First, we need to make a database connection using the establish function. Rather than using the ? syntax, we'll use .expect to unwrap our results in this article. This is less safe, but a little easier to work with.
fn create_connection() -> PgConnection {
let database_url = "postgres://postgres:postgres@localhost/rust_db";
PgConnection::establish(&database_url)
.expect("Error Connecting to database")
}
fn main() {
let connection: PgConnection = create_connection();
...
}
Let's start now with insertion. Of course, we begin by creating one of our "Insertable" User items. We can then start writing an insertion query with the Diesel function insert_into.
Diesel's query functions are composable. We add different elements to the query until it is complete. With an insertion, we use values combined with the item we want to insert. Then, we call get_result with our connection. The result of an insertion is our "Entity" type.
fn create_user(conn: &PgConnection) -> UserEntity {
let u = User
{ name: "James".to_string()
, email: "james@test.com".to_string()
, age: 26};
diesel::insert_into(users::table).values(&u)
.get_result(conn).expect("Error creating user!")
}
Selecting Items
Selecting items is a bit more complicated. Diesel generates a dsl module for each of our tables. This allows us to use each column name as a value within "filters" and orderings. Let's suppose we want to fetch all the articles written by a particular user. We'll start our query on the articles table and call filter to start building our query. We can then add a constraint on the author_id field.
fn fetch_articles(conn: &PgConnection, uid: i32) -> Vec<ArticleEntity> {
use rust_web::schema::articles::dsl::*;
articles.filter(author_id.eq(uid))
...
}
We can also add an ordering to our query. Notice again how these functions compose. We also have to specify the return type we want when using the load function to complete our select query. The main case is to return the full entity. This is like SELECT * FROM in SQL lingo. Applying load will give us a vector of these items.
fn fetch_articles(conn: &PgConnection, uid: i32) -> Vec<ArticleEntity> {
use rust_web::schema::articles::dsl::*;
articles.filter(author_id.eq(uid))
.order(title)
.load::<ArticleEntity>(conn)
.expect("Error loading articles!")
}
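For intuition about what this query computes, here is an in-memory analogue using iterators. The trimmed ArticleEntity and the slice-based fetch_articles are hypothetical; Diesel actually sends SQL to Postgres and lets the database do this work. Note that String::cmp compares bytes, so the ordering can differ from the database's collation for non-ASCII titles.

```rust
#[derive(Debug, Clone, PartialEq)]
struct ArticleEntity {
    id: i32,
    title: String,
    author_id: i32,
}

// In-memory analogue of `articles.filter(author_id.eq(uid)).order(title)`.
fn fetch_articles(rows: &[ArticleEntity], uid: i32) -> Vec<ArticleEntity> {
    // WHERE author_id = uid
    let mut result: Vec<ArticleEntity> = rows
        .iter()
        .filter(|a| a.author_id == uid)
        .cloned()
        .collect();
    // ORDER BY title
    result.sort_by(|a, b| a.title.cmp(&b.title));
    result
}
```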
But we can also specify particular fields that we want to return. We'll see this in the final example, where our result type is a vector of tuples. This last query will be a join between our two tables. We start with users and apply the inner_join function.
fn fetch_all_names_and_titles(conn: &PgConnection) -> Vec<(String, String)> {
use rust_web::schema::users::dsl::*;
use rust_web::schema::articles::dsl::*;
users.inner_join(...
}
Then we join it to the articles table on the particular ID field. Because both of our tables have id fields, we have to namespace it to specify the user's ID field.
fn fetch_all_names_and_titles(conn: &PgConnection) -> Vec<(String, String)> {
use rust_web::schema::users::dsl::*;
use rust_web::schema::articles::dsl::*;
users.inner_join(
articles.on(author_id.eq(rust_web::schema::users::dsl::id)))...
}
Finally, we load our query to get the results. But notice, we use select and only ask for the name of the User and the title of the article. This gives us our final values, so that each element is a tuple of two strings.
fn fetch_all_names_and_titles(conn: &PgConnection) -> Vec<(String, String)> {
use rust_web::schema::users::dsl::*;
use rust_web::schema::articles::dsl::*;
users.inner_join(
articles.on(author_id.eq(rust_web::schema::users::dsl::id)))
.select((name, title)).load(conn).expect("Error on join query!")
}
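Here is the same inner join computed in memory with a nested loop, as a sketch for intuition only; the trimmed UserEntity and ArticleEntity types are hypothetical. Users with no articles drop out, just as with an SQL inner join. One caveat: SQL makes no ordering guarantee without an ORDER BY, while this version preserves the order of the input slices.

```rust
struct UserEntity {
    id: i32,
    name: String,
}

struct ArticleEntity {
    title: String,
    author_id: i32,
}

// Nested-loop analogue of
// `users.inner_join(articles.on(author_id.eq(users::id))).select((name, title))`.
fn names_and_titles(users: &[UserEntity], articles: &[ArticleEntity]) -> Vec<(String, String)> {
    users
        .iter()
        .flat_map(move |u| {
            articles
                .iter()
                // ON articles.author_id = users.id
                .filter(move |a| a.author_id == u.id)
                // SELECT users.name, articles.title
                .map(move |a| (u.name.clone(), a.title.clone()))
        })
        .collect()
}
```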
Conclusion
For my part, I prefer the functionality provided by Persistent in Haskell. But Diesel's method of providing a separate CLI to handle migrations is very cool as well. And it's good to see more sophisticated functionality in this relatively new language.
If you're still new to Rust, we have some more beginner-related material. Read our Rust Beginners Series or better yet, watch our Rust Video Tutorial!
Basic Postgres Data in Rust
For our next few articles, we're going to be exploring some more advanced concepts in Rust. Specifically, we'll be looking at parallel ideas from our Real World Haskell Series. In these first couple weeks, we'll be exploring how to connect Rust and a Postgres database. To start, we'll use the Rust Postgres library. This will help us create a basic database connection so we can make simple queries. You can see all the code for this article in action by looking at our RustWeb repository. Specifically, you'll want to check out the file pg_basic.rs.
If you're new to Rust, we have a couple beginner resources for you to start out with. You can read our Rust Beginners Series to get a basic introduction to the language. Or for some more in-depth explanations, you can watch our Rust Video Tutorial!
Creating Tables
We'll start off by making a client object to connect to our database. This uses a connection string like we would with any Postgres library.
let conn_string = "host=localhost port=5432 user=postgres";
let mut client : Client = Client::connect(conn_string, NoTls)?;
Note that the connect function returns a Result<Client, Error>. In Haskell, we would write this as Either Error Client. By using ? at the end of our call, we can immediately unwrap the Client: if the call fails, ? returns the error from the enclosing function instead. The caveat is that this only compiles if the enclosing function itself returns a Result whose error type can be built from this Error. This is an interesting monadic behavior Rust gives us. Pretty much all our functions in this article will use this ? behavior.
Now that we have a client, we can use it to run queries. The catch is that we have to know the Raw SQL ourselves. For example, here's how we can create a table to store some users:
client.batch_execute("\
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
email TEXT NOT NULL,
age INTEGER NOT NULL
)
")?;
Inserting with Interpolation
A raw query like that with no real result is the simplest operation we can perform. But any non-trivial program will require us to customize the queries programmatically. To do this we'll need to interpolate values into the middle of our queries. We can do this with execute (as opposed to batch_execute).
Let's try creating a user. As with batch_execute, we need a query string. This time, the query string will contain placeholders like $1 and $2 that we'll fill in with variables. We provide these variables as a slice of references. Here's what it looks like with a user:
let name = "James";
let email = "james@test.com";
let age = 26;
client.execute(
"INSERT INTO users (name, email, age) VALUES ($1, $2, $3)",
&[&name, &email, &age],
)?;
Again, we're using a raw query string. All the values we interpolate must implement the trait postgres_types::ToSql. We'll see why this matters a bit later.
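To see the same create, insert, and select cycle with bound parameters in a form you can run without a Postgres server, here is a minimal sketch using Python's standard sqlite3 module. It is an analogue of the Rust code, not the Rust library: sqlite3 uses ? placeholders where rust-postgres uses $1, $2, SQLite's INTEGER PRIMARY KEY stands in for SERIAL, and the demo_users function and in-memory database are assumptions for illustration.

```python
import sqlite3


def demo_users():
    """Create a users table, insert one row with bound parameters, read it back."""
    conn = sqlite3.connect(":memory:")  # stands in for the Postgres server
    conn.execute(
        "CREATE TABLE users ("
        " id INTEGER PRIMARY KEY,"
        " name TEXT NOT NULL,"
        " email TEXT NOT NULL,"
        " age INTEGER NOT NULL)"
    )
    name, email, age = "James", "james@test.com", 26
    # The values travel separately from the SQL text and are bound to the
    # ? placeholders by the driver, rather than pasted into the string.
    conn.execute(
        "INSERT INTO users (name, email, age) VALUES (?, ?, ?)",
        (name, email, age),
    )
    rows = conn.execute("SELECT * FROM users").fetchall()
    conn.close()
    return rows
```

As with execute in Rust, the query text never contains the user's data, so quoting and escaping are the driver's job.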
Fetching Results
The last main type of query we can perform is to fetch our results. We can use our client to call the query function, which returns a vector of Row objects (wrapped in a Result, so we use ? again):
for row in client.query("SELECT * FROM users", &[])? {
    ...
}
For more complicated SELECT statements we would interpolate parameters, as with insertion above. The Row also exposes its Columns (through the columns method) if we want to inspect the data, but in our case it's a little easier to use get with an index to access the different fields. Like our raw SQL calls, this is unsafe in a couple of ways. If we use an out-of-bounds index, get will panic at runtime. And if we try to convert a field to the wrong data type, it will also panic.
for row in client.query("SELECT * FROM users", &[])? {
    let id: i32 = row.get(0);
    let name: &str = row.get(1);
    let email: &str = row.get(2);
    let age: i32 = row.get(3);
    ...
}
We could then use these individual values to populate whatever data types we wanted on our end.
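Positional access has exactly the failure modes described above. The hypothetical helper user_from_row below mirrors the row.get(0) through row.get(3) calls on a plain Python tuple. Python won't convert types for us, so the sketch checks them explicitly; a mismatch then fails loudly, much as Row::get panics.

```python
def user_from_row(row):
    """Build a user record from an (id, name, email, age) tuple by index."""
    # An out-of-range index raises IndexError, like an out-of-bounds row.get.
    user_id, name, email, age = row[0], row[1], row[2], row[3]
    # Check the types we expect, like the annotated types in the Rust code.
    for value, expected in ((user_id, int), (name, str), (email, str), (age, int)):
        if not isinstance(value, expected):
            raise TypeError(f"expected {expected.__name__}, got {type(value).__name__}")
    return {"id": user_id, "name": name, "email": email, "age": age}
```

Both failures only show up when the code runs, which is the runtime fragility discussed at the end of this article.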
Joining Tables
If we want to link two tables together, of course we'll also have to know how to do this with Raw SQL. For example, we can make our articles table:
client.batch_execute("\
CREATE TABLE articles (
id SERIAL PRIMARY KEY,
title TEXT NOT NULL,
body TEXT NOT NULL,
published_at TIMESTAMP WITH TIME ZONE NOT NULL,
author_id INTEGER REFERENCES users(id)
)
")?;
Then, after retrieving a user's ID, we can insert an article written by that user.
for row in client.query("SELECT * FROM users", &[])? {
    let id: i32 = row.get(0);
    let title: &str = "A Great Article!";
    let body: &str = "You should share this with friends.";
    let cur_time: DateTime<Utc> = Utc::now();
    client.execute(
        "INSERT INTO articles (title, body, published_at, author_id) VALUES ($1, $2, $3, $4)",
        &[&title, &body, &cur_time, &id],
    )?;
}
One of the tricky parts is that this won't compile if you only use the basic postgres dependency in Rust! There isn't a native ToSql implementation for the DateTime<Utc> type. However, Rust dependencies can have optional "features". This concept doesn't really exist in Haskell, except through extra packages. You'll need to enable the with-chrono feature that matches the version of the chrono library you use (here, with-chrono-0_4). This feature contains the necessary ToSql implementation. Here's what the structure looks like in our Cargo.toml:
[dependencies]
chrono="0.4"
postgres={version="0.17.3", features=["with-chrono-0_4"]}
After this, our code will compile!
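The articles table and a join over it can be sketched the same way with Python's sqlite3 module, again as a runnable analogue rather than the Rust code. SQLite has no TIMESTAMP WITH TIME ZONE type, so this sketch stores an ISO 8601 string from datetime.now(timezone.utc) instead. The demo_articles name and the final join query, which lists each author's name alongside their article titles, are illustrative assumptions.

```python
import sqlite3
from datetime import datetime, timezone


def demo_articles():
    """Create users and articles, insert an article per user, then join them."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")
    conn.execute(
        "CREATE TABLE articles ("
        " id INTEGER PRIMARY KEY,"
        " title TEXT NOT NULL,"
        " body TEXT NOT NULL,"
        " published_at TEXT NOT NULL,"  # ISO 8601 string instead of TIMESTAMPTZ
        " author_id INTEGER REFERENCES users(id))"
    )
    conn.execute("INSERT INTO users (name) VALUES (?)", ("James",))
    # Fetch the IDs first, then insert one article for each user.
    for (user_id,) in conn.execute("SELECT id FROM users").fetchall():
        conn.execute(
            "INSERT INTO articles (title, body, published_at, author_id)"
            " VALUES (?, ?, ?, ?)",
            ("A Great Article!", "You should share this with friends.",
             datetime.now(timezone.utc).isoformat(), user_id),
        )
    joined = conn.execute(
        "SELECT users.name, articles.title FROM users"
        " INNER JOIN articles ON articles.author_id = users.id"
    ).fetchall()
    conn.close()
    return joined
```

Note that the column types and the join condition are still plain strings here, so mistakes in them surface only at runtime.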
Runtime Problems
Now there are lots of reasons we wouldn't want to use a library like this in a formal project. One of the big principles of Rust (and Haskell) is catching errors at compile time. And writing out functions with lots of raw SQL like this makes our program very prone to runtime errors. I encountered several of these as I was writing this small program! At one point, I started writing the SELECT query and absentmindedly forgot to complete it until I ran my program!
At another point, I couldn't decide what timestamp format to use in Postgres. I went back and forth between using a TIMESTAMP or just an INTEGER for the published_at field. I needed to coordinate the SQL for both the table creation query and the fetching query. I often managed to change one but not the other, resulting in annoying runtime errors. I finally discovered I needed TIMESTAMP WITH TIME ZONE and not merely TIMESTAMP, since the postgres crate maps DateTime<Utc> to TIMESTAMP WITH TIME ZONE (a plain TIMESTAMP corresponds to chrono's NaiveDateTime). This was a rather painful process with this setup.
Conclusion
Next week, we'll explore Diesel, a library that lets us use schemas to catch more of these issues at compile time. The framework is more comparable to Persistent in Haskell. It gives us an ORM (Object Relational Mapping) so that we don't have to write raw SQL. This approach is much more suited to languages like Haskell and Rust!
To try out tasks like this in Haskell, take a look at our Production Checklist! It includes a couple different libraries for interacting with databases using ORMs.
Preparing for Rust!
Next week, we're going to change gears a bit and start some interesting projects with Rust! Towards the end of last year, we dabbled a bit with Rust and explored some of the basics of the language. In our next series of blog articles, we're going to take a deep dive into some more advanced concepts.
We'll explore several different Rust libraries across various topics. We'll consider data serialization, web servers and databases, among others. We'll build a couple small apps, and compare the results to our earlier work with Haskell.
To get ready for this series, you should brush up on your Rust basics! To help, we've wrapped up our Rust content into a permanent series on the Beginners page! Here's an overview of that series:
Part 1: Basic Syntax
We start out by learning about Rust's syntax. We'll see quite a few differences from Haskell. But there are also some similarities in unexpected places.
Part 2: Memory Management
One of the major things that sets Rust apart from other languages is how it manages memory. In the second part, we'll learn a bit about how Rust's memory system works.
Part 3: Data Types
In the third part of the series, we'll explore how to make our own data types in Rust. We'll see that Rust borrows some of Haskell's neat ideas!
Part 4: Cargo Package Manager
Cargo is Rust's equivalent of Stack and Cabal. It will be our package and dependency manager. In part 4, we see how to make basic Rust projects using Cargo.
Part 5: Lifetimes and Collections
In the final part, we'll look at some more advanced collection types in Rust. Because of Rust's memory model, we'll need some special rules for handling items in collections. This will lead us to the idea of lifetimes.
If you prefer video content, our Rust Video Tutorial also provides a solid foundation. It goes through all the topics in this series, starting from installation. Either way, stay tuned for new blog content, starting next week!
Summer Course Sale!
This week we have some exciting news! Back in March, we opened our Practical Haskell course for enrollment. The first round of students has had a chance to go through the course. So we're now opening it up for general enrollment!
This course goes through some more practical concepts and libraries you might use on a real world project. Here's a sneak peek at some of the skills you'll learn:
- Making a web server with Persistent and Servant
- Deploying a Haskell project using Heroku and Circle CI
- Making a web frontend with Elm, and connecting it to the Haskell backend
- Using Monad Transformers and Free Effects to organize our application
- Test driven development in Haskell
As a special bonus, for this week only, both of our courses are on sale, $100 off their normal prices! So if you're not ready for Practical Haskell, you can take a look at Haskell From Scratch. With that said, if you buy either course now, you'll have access to all the materials indefinitely! Prices will go back to normal after this Sunday, so head to the course pages now!
Next week, we'll start getting back into the swing of things by reviewing some of our Rust basics!
Mid-Summer Break, Open AI Gym Series!
We're taking a little bit of a mid-summer break from new content here at MMH. But we have done some extra work in organizing the site! Last week we wrapped up our series on Haskell and the Open AI Gym. We've now added that series as a permanent fixture on the advanced section of the page!
Here's a quick summary of the series:
Part 1: Frozen Lake Primer
The first part introduces the Open AI framework and goes through the Frozen Lake example. It presents the core concept of an environment.
Part 2: Frozen Lake in Haskell
In the second part, we write a basic version of Frozen Lake in Haskell.
Part 3: Blackjack
Next, we expand on our knowledge of games and environments to write a second game. This one is based on casino Blackjack, and it starts to show us common elements in games.
Part 4: Q-Learning
Now we start getting into the ideas of reinforcement learning. We'll explore Q-Learning, one of the simplest techniques in this field. We'll apply this approach to both of our games.
Part 5: Generalized Environments
Now that we've seen the learning process in action, we can start generalizing our games. We'll create an abstract notion of what an Environment is. Just as Python has a specific API for their games, so will we! In true Haskell fashion, we'll represent this API with a type family!
Part 6: Q-Learning with Tensors in Python
In part 6, we'll take our Q-learning process a step further by using TensorFlow. We'll see how we can learn a more general function than we had before. We'll start this process in Python, where the mathematical operations are clearer.
Part 7: Q-Learning with Tensors in Haskell
Once we know how Q-Learning works with Python, we'll apply these techniques in Haskell as well! Once you get here, you'd better be ready to use your Haskell TensorFlow skills!
Part 8: Rendering with Gloss
In the final part of the series, we'll see how we can use the Gloss library to render our Haskell games!
You can take a look at the series summary page for more details!
In a couple weeks, we'll be back, this time with some fresh Rust content! Take a look at our Rust Video Tutorial to get a head start on that!
Rendering Frozen Lake with Gloss!
We've spent the last few weeks exploring some of the ideas in the Open AI Gym framework. We made a couple games, generalized them, and applied some machine learning techniques. When it comes to rendering our games though, we're still relying on a very basic command line text format.
But if we want to design agents for more visually appealing games, we'll need a better solution! Last year, we spent quite a lot of time learning about the Gloss library. This library makes it easy to create simple games and render them using OpenGL. Take a look at this article for a summary of our work there and some links to the basics.
In this article, we'll explore how we can draw some connections between Gloss and our Open AI Gym work. We'll see how we can take the functions we've already written and use them within Gloss!
Gloss Basics
The key entrypoint for a Gloss game is the play function. At its core is the world type parameter, which we'll define for ourselves later.
play :: Display -> Color -> Int
     -> world
     -> (world -> Picture)
     -> (Event -> world -> world)
     -> (Float -> world -> world)
     -> IO ()
We won't go into the first three parameters, but the rest are important. The first of these is our initial world state. The second is our rendering function, which creates a Picture for the current state. Then comes an "event handler". This takes user input events and updates the world based on the actions. Finally there is the update function. This changes the world based on the passage of time, rather than specific user inputs.
This structure should sound familiar, because it's a lot like our Open AI environments! The initial world is like the "reset" function. Then both systems have a "render" function. And the update functions are like our stepEnv function.
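To make the correspondence concrete, here is a minimal Python sketch of the kind of loop play drives, assuming a simplified model in which every event is followed by exactly one time step. Real Gloss calls the update function on its own timer at the requested frame rate, so simulate_play and its dt parameter are illustrative assumptions, not part of Gloss's API.

```python
def simulate_play(initial_world, render, handle_event, update, events, dt=0.05):
    """Drive pure world functions the way a Gloss-style game loop would.

    Returns every rendered frame so the run can be inspected afterwards.
    """
    world = initial_world
    frames = [render(world)]
    for event in events:
        world = handle_event(event, world)  # plays the role of Event -> world -> world
        world = update(dt, world)           # plays the role of Float -> world -> world
        frames.append(render(world))
    return frames
```

For example, with a counter as the world, simulate_play(0, str, lambda e, w: w + 1 if e == "up" else w, lambda dt, w: w, ["up", "x", "up"]) renders the frames "0", "1", "1", "2". None of the functions touch shared state; the loop alone threads the world from one call to the next.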
The main difference we'll see is that Gloss's functions work in a pure way. Recall our "environment" functions use the "State" monad. Let's explore this some more.
Re-Writing Environment Functions
Let's take a look at the basic form of these environment functions, in the Frozen Lake context:
resetEnv :: (Monad m) => StateT FrozenLakeEnvironment m Observation
stepEnv :: (Monad m) =>
  Action -> StateT FrozenLakeEnvironment m (Observation, Double, Bool)
renderEnv :: (MonadIO m) => StateT FrozenLakeEnvironment m ()
These all use State. This makes it easy to chain them together. But if we look at the implementations, a lot of them don't really need to use State. They tend to unwrap the environment at the start with get, calculate new results, and then have a final put call.
This means we can rewrite them to fit more within Gloss's pure structure! We'll ignore rendering, since that will be very different. But here are some alternate type signatures:
resetEnv' :: FrozenLakeEnvironment -> FrozenLakeEnvironment
stepEnv' :: Action -> FrozenLakeEnvironment
  -> (FrozenLakeEnvironment, Double, Bool)
We'll exclude Observation as an output, since the environment contains that through currentObservation. The implementation for each of these looks like the original. Here's what resetting looks like:
resetEnv' :: FrozenLakeEnvironment -> FrozenLakeEnvironment
resetEnv' fle = fle
  { currentObservation = 0
  , previousAction = Nothing
  }
Now for stepping our environment forward:
stepEnv' :: Action -> FrozenLakeEnvironment -> (FrozenLakeEnvironment, Double, Bool)
stepEnv' act fle = (finalEnv, reward, done)
  where
    currentObs = currentObservation fle
    (slipRoll, gen') = randomR (0.0, 1.0) (randomGenerator fle)
    allLegalMoves = legalMoves currentObs (dimens fle)
    numMoves = length allLegalMoves - 1
    (randomMoveIndex, finalGen) = randomR (0, numMoves) gen'
    newObservation = ... -- Random move, or apply the action
    (done, reward) = case (grid fle) A.! newObservation of
      Goal -> (True, 1.0)
      Hole -> (True, 0.0)
      _ -> (False, 0.0)
    finalEnv = fle
      { currentObservation = newObservation
      , randomGenerator = finalGen
      , previousAction = Just act
      }
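For readers who want to run the idea, here is a minimal Python sketch of a pure Frozen Lake step. It makes several assumptions that are not from the article: the standard 4x4 Gym map, a hypothetical slip_chance of 0.2, and a legal_moves helper that lists neighbouring cells (the article's legalMoves and the elided newObservation logic may differ). Instead of threading randomGenerator, the two random draws are passed in explicitly as slip_roll and random_index, which keeps the function pure and easy to test.

```python
from dataclasses import dataclass, replace
from typing import Optional

# Standard 4x4 Frozen Lake map, read row by row: S start, F frozen, H hole, G goal.
GRID = "SFFFFHFHFFFHHFFG"
SIZE = 4
MOVES = {"left": (0, -1), "down": (1, 0), "right": (0, 1), "up": (-1, 0)}


@dataclass(frozen=True)
class Env:
    observation: int = 0
    previous_action: Optional[str] = None


def apply_move(obs, action):
    """Move one cell in the given direction, staying put at the grid's edge."""
    row, col = divmod(obs, SIZE)
    d_row, d_col = MOVES[action]
    new_row, new_col = row + d_row, col + d_col
    if 0 <= new_row < SIZE and 0 <= new_col < SIZE:
        return new_row * SIZE + new_col
    return obs


def legal_moves(obs):
    """Observations reachable in one step without leaving the grid."""
    return [apply_move(obs, a) for a in MOVES if apply_move(obs, a) != obs]


def step_env(action, env, slip_roll, random_index, slip_chance=0.2):
    """Pure step: returns (new_env, reward, done) and never mutates env."""
    if slip_roll < slip_chance:
        # Slipped: take a random legal move instead of the requested one.
        moves = legal_moves(env.observation)
        new_obs = moves[random_index % len(moves)]
    else:
        new_obs = apply_move(env.observation, action)
    tile = GRID[new_obs]
    if tile == "G":
        done, reward = True, 1.0
    elif tile == "H":
        done, reward = True, 0.0
    else:
        done, reward = False, 0.0
    new_env = replace(env, observation=new_obs, previous_action=action)
    return new_env, reward, done
```

Because Env is frozen, the caller's environment is untouched and the updated one comes back in the result tuple, matching the shape of stepEnv'.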
What's even better is that we can now rewrite our original State functions using resetEnv' and stepEnv'!
resetEnv :: (Monad m) => StateT FrozenLakeEnvironment m Observation
resetEnv = do
  modify resetEnv'
  gets currentObservation
stepEnv :: (Monad m) =>
  Action -> StateT FrozenLakeEnvironment m (Observation, Double, Bool)
stepEnv act = do
  fle <- get
  let (finalEnv, reward, done) = stepEnv' act fle
  put finalEnv
  return (currentObservation finalEnv, reward, done)
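Python has no StateT, but the underlying idea can be sketched directly: a state action is just a function from a state to a (result, new state) pair. The modify, gets, and then helpers below are hypothetical Python stand-ins for the Haskell combinators, with the environment modelled as a plain dict, so this illustrates the pattern rather than reproducing the library.

```python
def modify(f):
    """State action that replaces the state with f(state) and returns None."""
    return lambda state: (None, f(state))


def gets(f):
    """State action that reads f(state) without changing the state."""
    return lambda state: (f(state), state)


def then(first, second):
    """Run first, discard its result, then run second on the updated state."""
    def combined(state):
        _, intermediate = first(state)
        return second(intermediate)
    return combined


def reset_pure(env):
    """Pure reset in the spirit of resetEnv': a new dict, input left untouched."""
    return {**env, "observation": 0, "previous_action": None}


# The analogue of: resetEnv = modify resetEnv' >> gets currentObservation
reset_env = then(modify(reset_pure), gets(lambda env: env["observation"]))
```

Calling reset_env({"observation": 7, "previous_action": "up"}) returns the observation 0 together with the reset environment, while the original dict keeps its old values.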
Implementing Gloss
Now let's see how this ties in with Gloss. It might be tempting to use our Environment as the world type. But it can be useful to attach other information as well. For example, we can also include the current GameResult, telling us if we've won, lost, or if the game is still going.
data GameResult =
  GameInProgress |
  GameWon |
  GameLost
  deriving (Show, Eq)
data World = World
  { environment :: FrozenLakeEnvironment
  , gameResult :: GameResult
  }
Now we can start building the other pieces of our game. There aren't really any "time" updates in our game, except to update the result based on our location:
updateWorldTime :: Float -> World -> World
updateWorldTime _ w = case tile of
    Goal -> World fle GameWon
    Hole -> World fle GameLost
    _ -> w
  where
    fle = environment w
    obs = currentObservation fle
    tile = grid fle A.! obs
When it comes to handling inputs, we need to start with the case of restarting the game. When the game isn't GameInProgress, only the "enter" key matters. This resets everything, using resetEnv':
handleInputs :: Event -> World -> World
handleInputs event w
  | gameResult w /= GameInProgress = case event of
      (EventKey (SpecialKey KeyEnter) Down _ _) ->
        World (resetEnv' fle) GameInProgress
      _ -> w
  ...
Now we handle each directional input key. We'll make a helper function at the bottom that does the business of calling stepEnv'.
handleInputs :: Event -> World -> World
handleInputs event w
  -- Once the game is over, Enter restarts and every other event is ignored
  | gameResult w /= GameInProgress = case event of
      (EventKey (SpecialKey KeyEnter) Down _ _) ->
        World (resetEnv' fle) GameInProgress
      _ -> w
  | otherwise = case event of
      (EventKey (SpecialKey KeyUp) Down _ _) ->
        w { environment = finalEnv MoveUp }
      (EventKey (SpecialKey KeyRight) Down _ _) ->
        w { environment = finalEnv MoveRight }
      (EventKey (SpecialKey KeyDown) Down _ _) ->
        w { environment = finalEnv MoveDown }
      (EventKey (SpecialKey KeyLeft) Down _ _) ->
        w { environment = finalEnv MoveLeft }
      _ -> w
  where
    fle = environment w
    -- Step the environment purely, keeping only the new environment
    finalEnv action =
      let (fe, _, _) = stepEnv' action fle
      in fe
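The same two-phase input handling can be sketched in Python. The string key names, the World dataclass, and the step and reset parameters are hypothetical stand-ins for Gloss's SpecialKey events, the World record, and stepEnv'/resetEnv'. The sketch keeps the catch-all case for every non-Enter key after the game ends.

```python
from dataclasses import dataclass, replace

# Hypothetical key names; Gloss reports these as SpecialKey values instead.
KEY_TO_ACTION = {"up": "MoveUp", "right": "MoveRight", "down": "MoveDown", "left": "MoveLeft"}


@dataclass(frozen=True)
class World:
    observation: int
    result: str = "in_progress"  # or "won" / "lost"


def handle_input(key, world, step, reset):
    """Return the next world for a key press; step and reset are pure functions."""
    if world.result != "in_progress":
        # Once the game is over, Enter restarts and every other key is ignored.
        if key == "enter":
            return World(reset(world.observation))
        return world
    action = KEY_TO_ACTION.get(key)
    if action is None:
        return world  # not a directional key
    return replace(world, observation=step(action, world.observation))
```

As in the Haskell version, an unhandled key returns the world unchanged instead of failing.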
The last step is rendering the environment with a drawEnvironment function. This just requires a working knowledge of constructing the Picture type in Gloss. It's a little tedious, so I've included the full implementation as an appendix at the bottom. We can then combine all these pieces like so:
main :: IO ()
main = do
  env <- basicEnv
  play windowDisplay white 20
    (World env GameInProgress)
    drawEnvironment
    handleInputs
    updateWorldTime
After we have all these pieces, we can run our game, moving our player around to reach the green tile while avoiding the black tiles!
Conclusion
With a little more plumbing, it would be possible to combine this with the rest of our "Environment" work. There are some definite challenges. Our current environment setup doesn't have a "time update" function. Combining machine learning with Gloss rendering would also be interesting. This is the end of our Open Gym series for now, but I'll definitely be working on this project more in the future! Next week we'll have a summary and review what we've learned!
Take a look at our Github repository to see all the code we wrote in this series! The code for this article is on the gloss branch. And don't forget to Subscribe to Monday Morning Haskell to get our monthly newsletter!
Appendix: Rendering Frozen Lake
A lot of numbers here are hard-coded for a 4x4 grid, where each cell is 100x100. Notice particularly that we have a text message if we've won or lost.
windowDisplay :: Display
windowDisplay = InWindow "Window" (400, 400) (10, 10)
drawEnvironment :: World -> Picture
drawEnvironment world
  | gameResult world == GameWon = Translate (-150) 0 $ Scale 0.12 0.25
      (Text "You've won! Press enter to restart!")
  | gameResult world == GameLost = Translate (-150) 0 $ Scale 0.12 0.25
      (Text "You've lost :( Press enter to restart.")
  | otherwise = Pictures [tiles, playerMarker]
  where
    observationToCoords :: Word -> (Word, Word)
    observationToCoords w = quotRem w 4
    renderTile :: (Word, TileType) -> Picture
    renderTile (obs, tileType) =
      let (centerX, centerY) = rowColToCoords . observationToCoords $ obs
          color' = case tileType of
            Goal -> green
            Hole -> black
            _ -> blue
      in Translate centerX centerY (Color color' (Polygon [(-50, -50), (-50, 50), (50, 50), (50, -50)]))
    tiles = Pictures $ map renderTile (A.assocs (grid . environment $ world))
    (px, py) = rowColToCoords . observationToCoords $ (currentObservation . environment $ world)
    playerMarker = translate px py (Color red (ThickCircle 10 3))
rowColToCoords :: (Word, Word) -> (Float, Float)
rowColToCoords (row, col) = (100 * (fromIntegral col - 1.5), 100 * (1.5 - fromIntegral row))
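The coordinate arithmetic in rowColToCoords is easy to check on its own. This Python sketch reproduces it, generalising the hard-coded 100 and 1.5 into hypothetical cell and size parameters whose defaults give the 4x4 values. Gloss places the origin at the window centre with y pointing up, which is why rows are subtracted from the offset.

```python
def observation_to_row_col(obs, size=4):
    """Split a cell index into (row, col), like quotRem obs 4."""
    return divmod(obs, size)


def row_col_to_coords(row, col, cell=100.0, size=4):
    """Centre of a cell in Gloss-style coordinates: origin at the centre, y up."""
    offset = (size - 1) / 2  # 1.5 for a 4x4 grid
    return (cell * (col - offset), cell * (offset - row))
```

For example, observation 0 (top left) lands at (-150.0, 150.0), observation 15 (bottom right) at (150.0, -150.0), and observation 6 (row 1, column 2) at (50.0, 50.0).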
Training our Agent with Haskell!
In the previous part of the series, we used the ideas of Q-Learning together with TensorFlow. We got a more general solution to our agent that didn't need a table for every state of the game.
This week, we'll take the final step and implement this TensorFlow approach in Haskell. We'll see how to integrate this library with our existing Environment
system. It works out quite smoothly, with a nice separation between our TensorFlow logic and our normal environment logic!
This article requires a working knowledge of the Haskell TensorFlow integration. If you're new to this, you should download our Guide showing how to work with this framework. You can also read our original Machine Learning Series for some more details! In particular, the second part will go through the basics of tensors.
Building Our TF Model
The first thing we want to do is construct a "model". This model type will store three items. The first will be the tensor for our weights. The other two will be functions in the TensorFlow Session monad. The first function will provide scores for the different moves in a position, so we can choose our move. The second will allow us to train the model and update the weights.
data Model = Model
  { weightsT :: Variable Float
  , chooseActionStep :: TensorData Float -> Session (Vector Float)
  , learnStep :: TensorData Float -> TensorData Float -> Session ()
  }
The input for choosing an action is our world observation state, converted to Float values in a one-hot vector of size 16. The result will be 4 floating point values for the scores. Then our learning step will take in the observation as well as a set of 4 values. These are the "target" values we're training our model on.
We can construct our model within the Session monad. In the first part of this process we define our weights and use them to determine the score of each move (results).
createModel :: Session Model
createModel = do
  -- Choose Action
  inputs <- placeholder (Shape [1, 16])
  weights <- truncatedNormal (vector [16, 4]) >>= initializedVariable
  let results = inputs `matMul` readValue weights
  returnedOutputs <- render results
  ...
Now we make our "trainer". Our "loss" function is the sum of the squared differences between our results and the "target" outputs. We'll use the adam optimizer to learn values for our weights that minimize this loss.
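createModel :: Session Model
createModel = do
  -- Choose Action
  ...
  -- Train Network
  (nextOutputs :: Tensor Value Float) <- placeholder (Shape [4, 1])
  -- results has shape [1, 4]; reshape the [4, 1] targets to match, so the
  -- subtraction is elementwise rather than broadcast to a [4, 4] tensor
  let (diff :: Tensor Build Float) =
        reshape nextOutputs (vector [1, 4 :: Int32]) `sub` results
  let (loss :: Tensor Build Float) = reduceSum (diff `mul` diff)
  trainer_ <- minimizeWith adam loss [weights]
  ...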
Finally, we wrap these tensors into functions we can call using runWithFeeds. Recall that each feed provides us with a way to fill in one of our placeholder tensors.
createModel :: Session Model
createModel = do
  -- Choose Action
  ...
  -- Train Network
  ...
  -- Create Model
  let chooseStep = \inputFeed ->
        runWithFeeds [feed inputs inputFeed] returnedOutputs
  let trainStep = \inputFeed nextOutputFeed ->
        runWithFeeds [ feed inputs inputFeed
                     , feed nextOutputs nextOutputFeed
                     ]
                     trainer_
  return $ Model weights chooseStep trainStep
Our model now wraps all the different tensor operations we need! All we have to do is provide it with the correct TensorData. To see how that works, let's start integrating with our EnvironmentMonad!
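To see what these two model functions compute, here is a dependency-free Python sketch of the same linear model: a 16x4 weight matrix, scores as the one-hot row vector times the weights, and a loss equal to the sum of squared differences from the targets. It makes two swaps that should be stated plainly. It uses plain gradient descent in place of the adam optimizer, and it draws uniform initial weights in [0, 0.01], as the Python version of this series does, rather than truncatedNormal. The class and method names are hypothetical, chosen to echo the Model fields.

```python
import random


def one_hot(obs, size=16):
    """Row vector with a 1.0 at index obs, like obsToTensor."""
    return [1.0 if i == obs else 0.0 for i in range(size)]


class LinearQModel:
    """A 16x4 weight matrix trained by plain gradient descent on squared error."""

    def __init__(self, n_obs=16, n_actions=4, learning_rate=0.1, seed=0):
        rng = random.Random(seed)
        # Small random initial weights, drawn uniformly from [0, 0.01].
        self.weights = [[rng.uniform(0.0, 0.01) for _ in range(n_actions)]
                        for _ in range(n_obs)]
        self.learning_rate = learning_rate

    def choose_action_step(self, inputs):
        """Scores for each action: the row vector inputs times the weights."""
        n_actions = len(self.weights[0])
        return [sum(x * row[a] for x, row in zip(inputs, self.weights))
                for a in range(n_actions)]

    def learn_step(self, inputs, targets):
        """One gradient step on sum((targets - scores)^2); returns the prior loss."""
        diff = [t - s for t, s in zip(targets, self.choose_action_step(inputs))]
        # d(loss)/dW[i][a] = -2 * inputs[i] * diff[a], so step against it.
        for i, x in enumerate(inputs):
            for a, d in enumerate(diff):
                self.weights[i][a] += self.learning_rate * 2.0 * x * d
        return sum(d * d for d in diff)
```

With a one-hot input, only the row for the current observation receives a non-zero gradient. Repeated learn_step calls on the same input pull that row's scores toward the targets.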
Integrating With Environment
Our model's functions exist within the TensorFlow monad Session. So how, then, do we integrate this with our existing Environment code? The answer is, of course, to construct a new monad! This monad will wrap Session, while still giving us our FrozenLakeEnvironment! We'll keep the environment within a State, but we'll also keep a reference to our Model.
newtype FrozenLake a = FrozenLake
  (StateT (FrozenLakeEnvironment, Model) Session a)
  deriving (Functor, Applicative, Monad)
instance (MonadState FrozenLakeEnvironment) FrozenLake where
  get = FrozenLake (fst <$> get)
  put fle = FrozenLake $ do
    (_, model) <- get
    put (fle, model)
Now we can start implementing the actual EnvironmentMonad instance. Most of our existing types and functions will work with trivial modification. The only real change is that runEnv will need to run a TensorFlow session and create the model. Then it can use evalStateT.
instance EnvironmentMonad FrozenLake where
  type (Observation FrozenLake) = FrozenLakeObservation
  type (Action FrozenLake) = FrozenLakeAction
  type (EnvironmentState FrozenLake) = FrozenLakeEnvironment
  baseEnv = basicEnv
  currentObservation = currentObs <$> get
  resetEnv = resetFrozenLake
  stepEnv = stepFrozenLake
  runEnv env (FrozenLake action) = runSession $ do
    model <- createModel
    evalStateT action (env, model)
This is all we need to define the first class. But with TensorFlow, our environment is only useful if we use the tensor model! This means we need to fill in LearningEnvironment as well. This has two functions, chooseActionBrain and learnEnv, that use our tensors. Let's see how that works.
Choosing an Action
Choosing an action is straightforward. We'll once again start with the same format for sometimes choosing a random move:
chooseActionTensor :: FrozenLake FrozenLakeAction
chooseActionTensor = FrozenLake $ do
  (fle, model) <- get
  let (exploreRoll, gen') = randomR (0.0, 1.0) (randomGenerator fle)
  if exploreRoll < flExplorationRate fle
    then do
      let (actionRoll, gen'') = Rand.randomR (0, 3) gen'
      put $ (fle { randomGenerator = gen'' }, model)
      return (toEnum actionRoll)
    else do
      ...
As in Python, we'll need to convert an observation to a tensor type. This time, we'll create TensorData. This type wraps a vector, and our input should have the size 1x16. It has the format of a oneHot tensor. But it's easier to make this a pure function, rather than using a TensorFlow monad.
obsToTensor :: FrozenLakeObservation -> TensorData Float
obsToTensor obs = encodeTensorData (Shape [1, 16]) (V.fromList asList)
  where
    -- 1.0 at the observation's index, 0.0 everywhere else
    asList = replicate (fromIntegral obs) 0.0 ++
      [1.0] ++
      replicate (fromIntegral (15 - obs)) 0.0
Since we've already defined our chooseActionStep within the model, it's easy to use this! We convert the current observation, get the result values, and then pick the best index!
chooseActionTensor :: FrozenLake FrozenLakeAction
chooseActionTensor = FrozenLake $ do
  (fle, model) <- get
  -- Random move
  ...
    else do
      let obs1 = currentObs fle
      let obs1Data = obsToTensor obs1
      -- Use model!
      results <- lift ((chooseActionStep model) obs1Data)
      let bestMoveIndex = V.maxIndex results
      put $ (fle { randomGenerator = gen' }, model)
      return (toEnum bestMoveIndex)
Learning From the Environment
One unfortunate part of our current design is that we have to repeat some work in our learning function. To learn from our action, we need all the scores, not just the chosen action. So our learning function calls chooseActionStep again: once on the first observation to get the current scores, and once on the second observation to get the best score available from the next position.
learnTensor ::
  FrozenLakeObservation -> FrozenLakeObservation ->
  Reward -> FrozenLakeAction ->
  FrozenLake ()
learnTensor obs1 obs2 (Reward reward) action = FrozenLake $ do
  model <- snd <$> get
  let obs1Data = obsToTensor obs1
  let obs2Data = obsToTensor obs2
  -- Use the model! Scores for the current position...
  results <- lift ((chooseActionStep model) obs1Data)
  -- ...and the best score we could reach from the next position
  nextResults <- lift ((chooseActionStep model) obs2Data)
  let maxScore = V.maximum nextResults
  ...
We can now get our "target" values by substituting the reward plus the discounted max score at the index of the action we actually took. Then we train the model to move its scores for the first observation toward those targets, using our training step!
learnTensor ::
  FrozenLakeObservation -> FrozenLakeObservation ->
  Reward -> FrozenLakeAction ->
  FrozenLake ()
learnTensor obs1 obs2 (Reward reward) action = FrozenLake $ do
  ...
  let maxScore = V.maximum nextResults
  -- Only the taken action's score moves toward the Q-learning target
  let targetActionValues = results V.//
        [(fromEnum action, double2Float reward + (gamma * maxScore))]
  let targetActionData = encodeTensorData
        (Shape [4, 1])
        targetActionValues
  -- Use the model!
  lift $ (learnStep model) obs1Data targetActionData
  where
    gamma = 0.81
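The target computation at the heart of this learning step can be sketched in a few lines of Python. It follows the standard Q-learning update, the same one the Python learn_env in this series performs. Only the taken action's entry changes, becoming the reward plus gamma times the best score from the next position. The function name q_target is hypothetical, and the gamma default is the 0.81 used above.

```python
def q_target(current_scores, next_scores, action, reward, gamma=0.81):
    """Copy the current scores, then replace the taken action's entry with
    reward + gamma * (best score available from the next position)."""
    target = list(current_scores)
    target[action] = reward + gamma * max(next_scores)
    return target
```

For example, q_target([0.1, 0.2, 0.3, 0.4], [0.0, 0.5, 0.2, 0.1], action=2, reward=1.0) leaves entries 0, 1 and 3 alone and sets entry 2 to 1.0 + 0.81 * 0.5, which is about 1.405.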
Using chooseActionTensor and learnTensor, we can now fill in our LearningEnvironment class!
instance LearningEnvironment FrozenLake where
  chooseActionBrain = chooseActionTensor
  learnEnv = learnTensor
  -- Same as before
  explorationRate = ...
  reduceExploration = ...
We'll then be able to run this code just as we would our other Q-learning examples!
Conclusion
This wraps up the machine learning part of this series. We'll have one more article about Open Gym next week. We'll compare our current setup and the Gloss library. Gloss offers much more extensive possibilities for rendering our game and accepting input. So using it would expand the range of games we could play!
We'll definitely continue to expand on the Open Gym concept in the future! Expect a more formal approach to this at some point! For now, take a look at our Github repository for this series! This article's code is on the tensorflow branch!
Q-Learning with Tensors
In our last article we finished refactoring our Gym code to use a type family. This would make it much easier to add new games to our framework in the future. We're now in the closing stages of this series on AI and agent development. This week we're going to incorporate TensorFlow and perform some more advanced techniques.
We've used Q-Learning to train some agents to play simple games like Frozen Lake and Blackjack. Our existing approach uses an exhaustive table from observations to expected rewards. But in most games we won't be able to construct such an exhaustive table. The observation space will be too large, or it will be continuous. So in this article, we're going to explore how to use TensorFlow to build a more generic function we can learn. We'll start this process in Python, where there's a bit less overhead.
Next up, we'll be using TensorFlow with our Haskell code. We'll explore an alternative form of our FrozenLake monad using this approach. To make sure you're ready for it, download our Haskell TensorFlow Guide.
A Q-Function
Our goal here will be to make a more general Q-Function, instead of using a table. A Q-Function provides another way of writing our chooseAction function. With the table approach, each of the 16 possible observations had 4 scores, one for each of the actions we can take. To choose an action, we just take the index with the highest score.
We now want to incorporate a simple neural network for chooseAction. In our example, this network will consist of a single matrix of weights. The input to our network will be a vector of size 16. This vector will have all zeroes, except for the index of the current observation, which will be 1. Then the output of the network will be a vector of size 4. These will give the scores for each move from that observation. So our "weights" will have size 16x4.
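A short derivation shows how this single matrix relates to the old table. Write the input for observation $s$ as the one-hot row vector $e_s^\top$, the weights as $W \in \mathbb{R}^{16 \times 4}$, the target scores as $t$, and let $[i = s]$ be 1 when $i = s$ and 0 otherwise. Then the scores, the squared-error loss, and its gradient are:

```latex
\begin{aligned}
q(s) &= e_s^\top W = W_{s,:} \in \mathbb{R}^{1 \times 4} \\
L &= \sum_{a=1}^{4} \bigl(t_a - q_a(s)\bigr)^2 \\
\frac{\partial L}{\partial W_{i,a}} &= -2\,[i = s]\,\bigl(t_a - q_a(s)\bigr)
\end{aligned}
```

So with one-hot inputs the network's scores are exactly row $s$ of $W$, and a gradient step only changes that row. The 16x4 matrix plays the role of the old table, but the same code keeps working if the input stops being one-hot.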
So one useful helper function we can write already will be to convert an observation to an input tensor. This will make use of the identity matrix.
def obs_to_tensor(obs):
    # Row `obs` of the 16x16 identity matrix: a [1, 16] one-hot vector
    return np.identity(16)[obs:obs+1]
Building the Graph
We can now go ahead and start building our tensor graph. We'll start with the part that makes moves from an observation. For this quick Python script, we'll let the tensors live in the global namespace.
import gym
import numpy as np
import tensorflow as tf
tf.reset_default_graph()
env = gym.make('FrozenLake-v0')
inputs = tf.placeholder(shape=[1,16], dtype=tf.float32)
weights = tf.Variable(tf.random_uniform([16, 4], 0, 0.01))
output = tf.matmul(inputs, weights)
prediction = tf.argmax(output, 1)
Each time we make a move, we'll pass the current observation tensor as the input placeholder. Then we multiply it by the weights to get scores for each different output action. Our final "prediction" is the output index with the highest score. Notice how we initialize our network with random weights. This helps prevent our network from getting stuck early on.
We can use these tensors to construct our choose_action function. This will, of course, take the current observation as an input. But it will also take an epsilon value for the random move probability. We use sess.run to run our prediction and output tensors. If we choose a random move instead, we'll replace the actual "action" with a sample from the action space.
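def choose_action(input_obs, epsilon):
    # Run the network forward: best action index plus all 4 scores
    action, all_outputs = sess.run(
        [prediction, output],
        feed_dict={inputs: obs_to_tensor(input_obs)})
    # With probability epsilon, explore with a random action instead
    if np.random.rand(1) < epsilon:
        action[0] = env.action_space.sample()
    return action, all_outputs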
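The epsilon logic inside choose_action is the standard epsilon-greedy rule. Here is a minimal, framework-free sketch of just that rule, assuming the scores arrive as a plain list and the action space is the indices 0 to 3, as in Frozen Lake. The function name epsilon_greedy and its rng parameter are hypothetical, chosen only for illustration.

```python
import random

def epsilon_greedy(scores, epsilon, rng=random):
    """Return the index of the best score, or a random action with probability epsilon."""
    # Exploit: index of the highest-scoring action
    greedy = max(range(len(scores)), key=lambda a: scores[a])
    if rng.random() < epsilon:
        # Explore: sample uniformly from the whole action space
        return rng.randrange(len(scores))
    return greedy

scores = [0.1, 0.7, 0.3, 0.2]
print(epsilon_greedy(scores, 0.0))  # 1: with epsilon 0 we always exploit
```

As in the TensorFlow version, a random "explore" move can still land on the greedy action, since it samples from the whole action space.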
The Learning Process
The first part of our graph tells us how to make moves, but we also need to update our weights so the network gets better! To do this, we'll add a few more tensors.
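# Target scores, computed from the "next" state
next_output = tf.placeholder(shape=[1,4], dtype=tf.float32)
# Sum of squared differences between target and current scores
loss = tf.reduce_sum(tf.square(next_output - output))
trainer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
update_model = trainer.minimize(loss)
# initialize_all_variables is deprecated in TF 1.x
init = tf.global_variables_initializer()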
Let's go through these one by one. We need an extra input for the target values, which incorporate the "next" state of the game. We want the values we get in the original state to move closer to those! So our "loss" function is the sum of squared differences between our "current" output and the "target" output. Then we create a "trainer" that minimizes the loss with gradient descent. Because our weights are the only "variable" in the system, they're what gets updated to minimize this loss.
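To make the "trainer" less magical, here is what a single gradient-descent step on this loss does to our weights, derived by hand in NumPy. For loss = sum((target - x W)^2) with a one-hot row x, the gradient with respect to W is -2 xᵀ(target - x W), so only the row of W for the current observation changes. This is a hand-derived sketch for this single-layer case, not TensorFlow's implementation, and the starting values are made up.

```python
import numpy as np

def sgd_step(weights, x, target, learning_rate=0.1):
    """One gradient-descent step on loss = sum((target - x @ weights) ** 2)."""
    output = x @ weights                    # shape (1, 4)
    grad = -2.0 * x.T @ (target - output)   # shape (16, 4), nonzero only in the active row
    return weights - learning_rate * grad

def loss(weights, x, target):
    return float(np.sum((target - x @ weights) ** 2))

x = np.identity(16)[3:4]                    # one-hot input for observation 3
weights = np.zeros((16, 4))
target = np.array([[0.0, 0.0, 1.0, 0.0]])

new_weights = sgd_step(weights, x, target)
print(loss(weights, x, target), loss(new_weights, x, target))
```

With a one-hot input and a learning rate of 0.1, each step shrinks the error for that observation by a factor of 1 - 2 * 0.1 = 0.8, while every other observation's row is left alone.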
We can use this second group of tensors to construct our "learning" function.
def learn_env(current_obs, next_obs, reward, action, all_outputs):
    gamma = 0.81  # discount factor for future rewards
    # Best score available from the next state (no exploration)
    _, all_next_outputs = choose_action(next_obs, 0.0)
    next_max = np.max(all_next_outputs)
    # Only the entry for the action we took gets a new target
    target_outputs = all_outputs
    target_outputs[0, action[0]] = reward + gamma * next_max
    sess.run(
        [update_model, weights],
        feed_dict={inputs: obs_to_tensor(current_obs),
                   next_output: target_outputs})
We start by choosing an action from the "next" position (without randomness) and take the largest score from that choice. We combine this with the reward to form our "target" for the current output. In other words, the score for the action we took should equal the reward plus the best score we could get from the next position, discounted by gamma. The other three entries keep their current values, so they don't contribute to the loss. Then we update our model!
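The target computation in learn_env is the Q-learning update rule. Here is a minimal pure-Python version, with made-up score values, that copies the current scores and replaces only the entry for the action we took with reward + gamma * max(next scores). The name q_target is hypothetical.

```python
def q_target(all_outputs, action, reward, next_outputs, gamma=0.81):
    """Copy the current scores, replacing only the taken action's entry with the Q-learning target."""
    target = list(all_outputs)  # copy so the caller's scores are not mutated
    target[action] = reward + gamma * max(next_outputs)
    return target

current = [0.1, 0.2, 0.3, 0.4]
next_scores = [0.0, 0.5, 0.25, 0.1]
print(q_target(current, action=2, reward=0.0, next_outputs=next_scores))
```

Here the entry for action 2 becomes 0.0 + 0.81 * 0.5, and the other three scores are unchanged, so only that entry produces any loss.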
Playing the Game
Now all that's left is to play out the game! This looks a lot like code from previous parts, so we won't go into too much depth. The key section is in the middle of the loop. We choose our next action, use it to step the environment, and use the reward to learn.
rewards_list = []
with tf.Session() as sess:
    sess.run(init)
    epsilon = 0.9
    decay_rate = 0.9
    num_episodes = 10000
    for i in range(num_episodes):
        # Reset environment and get first new observation
        current_obs = env.reset()
        sum_rewards = 0
        done = False
        num_steps = 0
        while num_steps < 100:
            num_steps += 1
            # Choose, Step, Learn!
            action, all_outputs = choose_action(current_obs, epsilon)
            next_obs, reward, done, _ = env.step(action[0])
            learn_env(current_obs, next_obs, reward, action, all_outputs)
            sum_rewards += reward
            current_obs = next_obs
            if done:
                if i % 100 == 99:
                    epsilon *= decay_rate
                break
        rewards_list.append(sum_rewards)
Our results won't be quite as good as with the table approach. Learning a function with tensors makes our system much more general, but the tradeoff is that the results are less stable. We could, of course, improve the results by using more advanced algorithms. But we'll get into that another time!
Conclusion
Now that we know the core ideas behind using tensors for Q-Learning, it's time to do this in Haskell. Next week, we'll do a refresher on how Haskell works together with TensorFlow. We'll see how we can work these ideas into our existing Environment framework.
Refactored Game Play!
Last week, we implemented Q-learning for our Blackjack game. We found the solution looked a lot like Frozen Lake for the most part. So we created a new class, EnvironmentMonad, to combine the steps these games have in common. This week, we'll see a full implementation of that class. Our goal is to write a couple of generic gameLoop functions we can use for different modes of our game.
As always, the code for this article is on our GitHub repository! You'll mainly want to explore the source files with Environment in their name.
Expanding our Environment
Last time, we put together a basic idea of what a generic environment could look like. We also made a couple of separate "sub-classes" for rendering and learning.
class (Monad m) => EnvironmentMonad m where
  type Observation m :: *
  type Action m :: *
  resetEnv :: m (Observation m)
  stepEnv :: (Action m) -> m (Observation m, Reward, Bool)
class (MonadIO m, EnvironmentMonad m) => RenderableEnvironment m where
  renderEnv :: m ()
class (EnvironmentMonad m) => LearningEnvironment m where
  learnEnv ::
    (Observation m) -> (Observation m) -> Reward -> (Action m) -> m ()
There are still a couple of extra pieces we can add to make these classes more complete. One thing we're missing is a concrete expression of our state, which makes it difficult to run our environments from normal code. So let's add a new associated type, EnvironmentState, for the underlying state, as well as a runEnv function to "run" an action in that environment. We'll also want a generic way to get the current observation.
class (Monad m) => EnvironmentMonad m where
  type Observation m :: *
  type Action m :: *
  type EnvironmentState m :: *
  runEnv :: (EnvironmentState m) -> m a -> IO a
  currentObservation :: m (Observation m)
  resetEnv :: m (Observation m)
  stepEnv :: (Action m) -> m (Observation m, Reward, Bool)
Forcing runEnv to use IO is more restrictive than we'd like. In the future, we might explore how to get our environment to wrap a monad parameter to fix this.
We can also add a couple of items to our LearningEnvironment for the exploration rate. This way, generic code can read and reduce the exploration rate without knowing anything about the concrete environment. We'll also make the function for choosing an action a part of the class.
class (EnvironmentMonad m) => LearningEnvironment m where
  learnEnv ::
    (Observation m) -> (Observation m) -> Reward -> (Action m) -> m ()
  chooseActionBrain :: m (Action m)
  explorationRate :: m Double
  reduceExploration :: Double -> Double -> m ()
Game Loops
In previous iterations, we had gameLoop functions for each of our different environments. We can now write these in a totally generic way! Here's a simple loop that plays the game once and produces a result:
gameLoop :: (EnvironmentMonad m) =>
  m (Action m) -> m (Observation m, Reward)
gameLoop chooseAction = do
  newAction <- chooseAction
  (newObs, reward, done) <- stepEnv newAction
  if done
    then return (newObs, reward)
    else gameLoop chooseAction
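For readers more at home in Python, the same shape can be sketched with a typing.Protocol standing in for the EnvironmentMonad class: the loop only needs reset and step from the environment, and the action chooser is passed in, just like gameLoop's argument. This is an analogy rather than a translation (Python has no associated types, and the loop is iterative instead of recursive), and the toy Corridor environment is a hypothetical stand-in for Frozen Lake.

```python
from typing import Callable, Protocol, Tuple

class Environment(Protocol):
    def reset(self) -> int: ...
    def step(self, action: int) -> Tuple[int, float, bool]: ...

class Corridor:
    """Toy environment: walk right from cell 0 to cell `length - 1` for a reward of 1."""
    def __init__(self, length: int = 4) -> None:
        self.length = length
        self.pos = 0

    def reset(self) -> int:
        self.pos = 0
        return self.pos

    def step(self, action: int) -> Tuple[int, float, bool]:
        # Action 1 moves right; anything else moves left (clamped at cell 0)
        if action == 1:
            self.pos = min(self.pos + 1, self.length - 1)
        else:
            self.pos = max(self.pos - 1, 0)
        done = self.pos == self.length - 1
        return self.pos, (1.0 if done else 0.0), done

def game_loop(env: Environment, choose_action: Callable[[], int]) -> Tuple[int, float]:
    """Step until the episode ends, returning the final observation and reward (like gameLoop)."""
    while True:
        obs, reward, done = env.step(choose_action())
        if done:
            return obs, reward

env = Corridor()
env.reset()
print(game_loop(env, lambda: 1))  # (3, 1.0)
```

As in the Haskell main functions, the caller is responsible for resetting the environment before starting the loop.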
If we want to render the game between moves, we add a single renderEnv call before selecting each move. We also need the RenderableEnvironment constraint (which brings in MonadIO), and we render once more before returning the final result.
gameRenderLoop :: (RenderableEnvironment m) =>
  m (Action m) -> m (Observation m, Reward)
gameRenderLoop chooseAction = do
  renderEnv
  newAction <- chooseAction
  (newObs, reward, done) <- stepEnv newAction
  if done
    then renderEnv >> return (newObs, reward)
    else gameRenderLoop chooseAction
Finally, there are a couple of different loops we can write for a learning environment. First, we can have a generic loop for one iteration of the game. Notice how we rely on the class function chooseActionBrain. This means we don't need to pass in an action-choosing function as a parameter.
gameLearningLoop :: (LearningEnvironment m) =>
  m (Observation m, Reward)
gameLearningLoop = do
  oldObs <- currentObservation
  newAction <- chooseActionBrain
  (newObs, reward, done) <- stepEnv newAction
  learnEnv oldObs newObs reward newAction
  if done
    then return (newObs, reward)
    else gameLearningLoop
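The learning loop adds one ingredient: it remembers the observation before each step so that learnEnv sees both sides of the transition. Continuing the Python analogy, here is a self-contained sketch in which a hypothetical Learner object plays the role of chooseActionBrain and learnEnv. It simply always moves right and records each transition, so we can check that learning happens once per step with the right old and new observations. Unlike chooseActionBrain, its choose_action takes the observation as an explicit argument.

```python
from typing import List, Tuple

class Corridor:
    """Toy environment: walk right from cell 0 to the last cell for a reward of 1."""
    def __init__(self, length: int = 4) -> None:
        self.length = length
        self.pos = 0

    def reset(self) -> int:
        self.pos = 0
        return self.pos

    def current_observation(self) -> int:
        return self.pos

    def step(self, action: int) -> Tuple[int, float, bool]:
        if action == 1:
            self.pos = min(self.pos + 1, self.length - 1)
        else:
            self.pos = max(self.pos - 1, 0)
        done = self.pos == self.length - 1
        return self.pos, (1.0 if done else 0.0), done

class Learner:
    """Hypothetical 'brain': always moves right and records every transition it learns from."""
    def __init__(self) -> None:
        self.transitions: List[Tuple[int, int, float, int]] = []

    def choose_action(self, obs: int) -> int:
        return 1

    def learn(self, old_obs: int, new_obs: int, reward: float, action: int) -> None:
        self.transitions.append((old_obs, new_obs, reward, action))

def game_learning_loop(env: Corridor, brain: Learner) -> Tuple[int, float]:
    """Observe, choose, step, learn; repeat until done (mirrors gameLearningLoop)."""
    while True:
        old_obs = env.current_observation()
        action = brain.choose_action(old_obs)
        new_obs, reward, done = env.step(action)
        brain.learn(old_obs, new_obs, reward, action)
        if done:
            return new_obs, reward

env, brain = Corridor(), Learner()
env.reset()
print(game_learning_loop(env, brain))  # (3, 1.0)
print(brain.transitions)  # [(0, 1, 0.0, 1), (1, 2, 0.0, 1), (2, 3, 1.0, 1)]
```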
Then we can make another loop that runs many learning iterations, reducing the exploration rate once every 100 episodes.
gameLearningIterations :: (LearningEnvironment m) => m [Reward]
gameLearningIterations = forM [1..numEpisodes] $ \i -> do
  resetEnv
  when (i `mod` 100 == 99) $ do
    reduceExploration decayRate minEpsilon
  (_, reward) <- gameLearningLoop
  return reward
  where
    numEpisodes = 10000
    decayRate = 0.9
    minEpsilon = 0.01
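To get a feel for this schedule, here is the same arithmetic in plain Python. Episodes are numbered 1 to 10000, the exploration rate is reduced whenever i mod 100 == 99, and each reduction is clamped at minEpsilon, as the FrozenLake LearningEnvironment instance does with max. The starting rate of 0.9 is an assumption borrowed from the earlier Python script; gameLearningIterations itself doesn't fix one.

```python
def exploration_schedule(start=0.9, num_episodes=10000, decay_rate=0.9, min_epsilon=0.01):
    """Return the exploration rate in effect for each episode 1..num_episodes."""
    epsilon = start
    rates = []
    for i in range(1, num_episodes + 1):
        if i % 100 == 99:
            # Same rule as reduceExploration: decay, but never drop below the floor
            epsilon = max(min_epsilon, epsilon * decay_rate)
        rates.append(epsilon)
    return rates

rates = exploration_schedule()
print(rates[0], rates[98], rates[-1])
```

With these numbers, the rate reaches the 0.01 floor at episode 4299 (the 43rd reduction) and stays there for the rest of training.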
Concrete Implementations
Now we want to see how we actually implement these classes for our types. We'll show the examples for FrozenLake, but it's an identical process for Blackjack. We start by defining the monad type as a wrapper over our existing state.
-- Deriving Monad through the newtype requires GeneralizedNewtypeDeriving
newtype FrozenLake a = FrozenLake (StateT FrozenLakeEnvironment IO a)
  deriving (Functor, Applicative, Monad)
We'll want to make a MonadState instance for our monad over the environment type. This will make it easier to port over our existing code. We'll also need a MonadIO instance to help with rendering.
instance (MonadState FrozenLakeEnvironment) FrozenLake where
  get = FrozenLake get
  put fle = FrozenLake $ put fle
instance MonadIO FrozenLake where
  liftIO act = FrozenLake (liftIO act)
Then we want to change our function signatures to live in the desired monad. We can pretty much leave the functions themselves untouched.
resetFrozenLake :: FrozenLake FrozenLakeObservation
stepFrozenLake ::
  FrozenLakeAction -> FrozenLake (FrozenLakeObservation, Reward, Bool)
renderFrozenLake :: FrozenLake ()
Finally, we make the actual instance for the class. The only thing we haven't defined yet is the runEnv function, but this is a simple wrapper for evalStateT.
instance EnvironmentMonad FrozenLake where
  type (Observation FrozenLake) = FrozenLakeObservation
  type (Action FrozenLake) = FrozenLakeAction
  type (EnvironmentState FrozenLake) = FrozenLakeEnvironment
  runEnv env (FrozenLake action) = evalStateT action env
  currentObservation = FrozenLake (currentObs <$> get)
  resetEnv = resetFrozenLake
  stepEnv = stepFrozenLake
instance RenderableEnvironment FrozenLake where
  renderEnv = renderFrozenLake
There's a bit more we could do. We could now separate the "brain" portions of the environment without any issues. We wouldn't need to keep the Q-Table and the exploration rate in the state. This would improve our encapsulation. We could also make our underlying monads more generic.
Playing the Game
Now, playing our game is simple! We get our basic environment, reset it, and call our loop function! This code will let us play one iteration of Frozen Lake, using our own input:
main :: IO ()
main = do
  (env :: FrozenLakeEnvironment) <- basicEnv
  _ <- runEnv env action
  putStrLn "Done!"
  where
    action = do
      resetEnv
      (gameRenderLoop chooseActionUser
        :: FrozenLake (FrozenLakeObservation, Reward))
Once again, we can make this code work for Blackjack with a simple name substitution.
We can also make this work with our Q-learning code. We start with a simple instance for LearningEnvironment.
instance LearningEnvironment FrozenLake where
  learnEnv = learnQTable
  chooseActionBrain = chooseActionQTable
  explorationRate = flExplorationRate <$> get
  reduceExploration decayRate minEpsilon = do
    fle <- get
    let e = flExplorationRate fle
    let newE = max minEpsilon (e * decayRate)
    put $ fle { flExplorationRate = newE }
And now we use gameLearningIterations instead of gameRenderLoop!
main :: IO ()
main = do
  (env :: FrozenLakeEnvironment) <- basicEnv
  _ <- runEnv env action
  putStrLn "Done!"
  where
    action = do
      resetEnv
      (gameLearningIterations :: FrozenLake [Reward])
Conclusion
We're still pulling in two "extra" pieces besides the environment class itself: specific implementations of basicEnv and of action choosing. We could try to abstract these behind the class as well, with generic functions for choosing an action as a human and choosing at random. This would force us to make the action space more general too.
But for now, it's time to explore some more interesting learning algorithms. For our current Q-learning approach, we make a table with an entry for every possible game state. This doesn't scale to games with large or continuous observation spaces! Next week, we'll see how TensorFlow allows us to learn a Q function instead of a direct table.
We'll start in Python, but soon enough we'll be using TensorFlow in Haskell. Take a look at our guide for help getting everything installed!