Unit Tests and Benchmarks in Rust
For a couple months now, we've focused on some specific libraries you can use in Rust for web development. But we shouldn't lose sight of some other core language skills and mechanics. Whenever you write code, you should be able to show first that it works, and second that it works efficiently. If you're going to build a larger Rust app, you should also know a bit about unit testing and benchmarking. This week, we'll take a couple simple sorting algorithms as our examples to learn these skills.
As always, you can take a look at the code for this article on our Github Repo for the series. You can find this week's code specifically in sorters.rs.
For a more basic introduction to Rust, be sure to check out our Rust Beginners Series!
Insertion Sort
We'll start out this article by implementing insertion sort. This is one of the simpler sorting algorithms, though it is rather inefficient. We'll perform this sort "in place". This means our function won't return a value. Rather, we'll pass a mutable reference to our vector so we can manipulate its items. To help out, we'll also define a swap function that exchanges two elements through that same reference:
pub fn swap(numbers: &mut Vec<i32>, i: usize, j: usize) {
let temp = numbers[i];
numbers[i] = numbers[j];
numbers[j] = temp;
}
pub fn insertion_sorter(numbers: &mut Vec<i32>) {
...
}
At its core, insertion sort is a pretty simple algorithm. We maintain the invariant that the "left" part of the array is always sorted. (At the start, with only 1 element, this is clearly true). Then we loop through the array and "absorb" the next element into our sorted part. To absorb the element, we'll loop backwards through our sorted portion. Each time we find a larger element, we switch their places. When we finally encounter a smaller element, we know the left side is once again sorted.
pub fn insertion_sorter(numbers: &mut Vec<i32>) {
for i in 1..numbers.len() {
let mut j = i;
while j > 0 && numbers[j-1] > numbers[j] {
swap(numbers, j, j - 1);
j = j - 1;
}
}
}
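As a variation (not the version from the article's repository), the same loop can take a mutable slice instead of a `&mut Vec<i32>` and rely on the standard library's `swap` method for slices. The name `insertion_sort_slice` is made up for this sketch:

```rust
/// Sorts the slice in place with insertion sort.
/// Hypothetical variant of `insertion_sorter` that accepts any mutable slice.
pub fn insertion_sort_slice(numbers: &mut [i32]) {
    for i in 1..numbers.len() {
        let mut j = i;
        // Move the new element left until its left neighbor is no larger.
        while j > 0 && numbers[j - 1] > numbers[j] {
            numbers.swap(j, j - 1); // slice method from the standard library
            j -= 1;
        }
    }
}
```

Because a `&mut Vec<i32>` coerces to `&mut [i32]`, callers can still pass `&mut numbers` for a vector, and the function also works on arrays.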
Testing
Our algorithm is simple enough. But how do we know it works? The obvious answer is to write some unit tests for it. Rust is actually a bit different from Haskell and most other languages in its canonical approach to unit tests. In most languages, you'll make a separate test directory. But Rust encourages you to write unit tests in the same file as the function definition. We do this by having a section at the bottom of our file specifically for tests. We mark a test function with the test attribute:
#[test]
fn test_insertion_sort() {
...
}
To keep things simple, we'll generate a random vector of 100 integers with a random_vector helper and pass it to our function. We'll use assert! to verify that each number is no larger than the one after it.
#[test]
fn test_insertion_sort() {
let mut numbers: Vec<i32> = random_vector(100);
insertion_sorter(&mut numbers);
for i in 0..(numbers.len() - 1) {
assert!(numbers[i] <= numbers[i + 1]);
}
}
When we run the cargo test command, Cargo will automatically detect the tests in this file and run them.
running 1 test...
test sorter::test_insertion_sort ... ok
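The test above relies on a `random_vector` helper that this article doesn't show. Here is a minimal sketch of what such a helper could look like. The repository version presumably uses the `rand` crate, so the generator, the fixed seed, and the value range below are all assumptions made to keep the example free of dependencies. The `is_sorted` helper is also hypothetical:

```rust
/// Hypothetical stand-in for the article's `random_vector` helper.
/// Uses a small linear congruential generator with a fixed seed,
/// so the output is pseudo-random but repeatable.
pub fn random_vector(size: usize) -> Vec<i32> {
    let mut state: u64 = 0x2545_F491_4F6C_DD1D; // arbitrary seed (assumption)
    let mut numbers = Vec::with_capacity(size);
    for _ in 0..size {
        // Multiplier and increment from Knuth's MMIX generator;
        // wrapping arithmetic avoids overflow panics in debug builds.
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Keep the high bits and reduce to the range 0..1000.
        numbers.push((state >> 33) as i32 % 1000);
    }
    numbers
}

/// Returns true when every element is no larger than its successor.
pub fn is_sorted(numbers: &[i32]) -> bool {
    numbers.windows(2).all(|pair| pair[0] <= pair[1])
}
```

Checking with `windows(2)` also sidesteps the `numbers.len() - 1` subtraction, which would underflow if a test ever passed an empty vector.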
Benchmarking
So we know our code works, but how quickly does it work? When you want to check the performance of your code, you need to establish benchmarks. These are like test suites except that they report the average time it takes to perform a task.
Just as we had a test attribute for making test suites, we can use the bench attribute for benchmarks. Note that bench is still an unstable feature, so it requires the nightly compiler, with #![feature(test)] at the top of the crate and the Bencher type imported from the built-in test crate. Each benchmark function takes a mutable Bencher object as an argument. To time some code, we'll call iter on that object and pass a closure that will run our function.
#[bench]
fn bench_insertion_sort_100_ints(b: &mut Bencher) {
b.iter(|| {
let mut numbers: Vec<i32> = random_vector(100);
insertion_sorter(&mut numbers)
});
}
We can then run the benchmark with cargo bench.
running 2 tests
test sorter::test_insertion_sort ... ignored
test sorter::bench_insertion_sort_100_ints ... bench: 6,537 ns/iter (+/- 1,541)
So on average, it took about 6.5 microseconds to sort 100 numbers. On its own, this number doesn't tell us much. But we can get a clearer idea of the runtime of our algorithm by looking at benchmarks of different sizes. Suppose we make lists of 1000 and 10000:
#[bench]
fn bench_insertion_sort_1000_ints(b: &mut Bencher) {
b.iter(|| {
let mut numbers: Vec<i32> = random_vector(1000);
insertion_sorter(&mut numbers)
});
}
#[bench]
fn bench_insertion_sort_10000_ints(b: &mut Bencher) {
b.iter(|| {
let mut numbers: Vec<i32> = random_vector(10000);
insertion_sorter(&mut numbers)
});
}
Now when we run the benchmark, we can compare the results of these different runs:
running 4 tests
test sorter::test_insertion_sort ... ignored
test sorter::bench_insertion_sort_10000_ints ... bench: 65,716,130 ns/iter (+/- 11,193,188)
test sorter::bench_insertion_sort_1000_ints ... bench: 612,373 ns/iter (+/- 124,732)
test sorter::bench_insertion_sort_100_ints ... bench: 12,032 ns/iter (+/- 904)
We see that when we increase the problem size by a factor of 10, we increase the runtime by a factor of roughly 50 to 100! This is consistent with our simple insertion sort having an asymptotic runtime of O(n^2), which is not very good.
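To make that comparison concrete, we can divide each timing by the previous one. A purely quadratic algorithm would show a ratio of 100 for every tenfold increase in input size. The helper below is only an illustration (the name `growth_ratios` is made up), applied to the three timings from the run above:

```rust
/// Returns the ratio of each timing to the one before it.
pub fn growth_ratios(timings_ns: &[f64]) -> Vec<f64> {
    timings_ns
        .windows(2)
        .map(|pair| pair[1] / pair[0])
        .collect()
}
```

Calling `growth_ratios(&[12_032.0, 612_373.0, 65_716_130.0])` gives roughly 50.9 and 107.3. The first step falls short of 100, most likely because fixed overhead still matters at 100 elements, while the second step is close to the quadratic prediction.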
Quick Sort
There are many ways to sort more efficiently! Let's try our hand at quicksort. For this algorithm, we first "partition" our array. We'll choose a pivot value, and then move all the numbers smaller than the pivot to the left of the array, and all the greater numbers to the right. The upshot is that we know our pivot element is now in the correct final spot!
Here's what the partition algorithm looks like. It works on a specific sub-segment of our vector, indicated by start and end, where end is exclusive. We initially move the pivot element to the back, and then loop through the other elements of the segment. The i index tracks where our pivot will end up. Each time we encounter a number smaller than the pivot, we swap it into position i and then increment i. At the very end we swap our pivot element back into its place, and return its final index.
pub fn partition(
numbers: &mut Vec<i32>,
start: usize,
end: usize,
partition: usize)
-> usize {
let pivot_element = numbers[partition];
swap(numbers, partition, end - 1);
let mut i = start;
for j in start..(end - 1) {
if numbers[j] < pivot_element {
swap(numbers, i, j);
i = i + 1;
}
}
swap(numbers, i, end - 1);
i
}
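It can help to trace the partition on a small input. The sketch below repeats the same logic on a slice (using the standard library's `swap` method rather than our own helper) and runs it on a five-element vector. The function `partition_example` and its input values are made up for illustration:

```rust
/// Same partition logic as above, written against a mutable slice.
pub fn partition(
    numbers: &mut [i32],
    start: usize,
    end: usize,
    pivot_index: usize,
) -> usize {
    let pivot_element = numbers[pivot_index];
    numbers.swap(pivot_index, end - 1); // park the pivot at the back
    let mut i = start;
    for j in start..(end - 1) {
        if numbers[j] < pivot_element {
            numbers.swap(i, j); // move the smaller element to the left part
            i += 1;
        }
    }
    numbers.swap(i, end - 1); // put the pivot into its final position
    i
}

/// Partitions [3, 8, 1, 5, 2] around the value 5 (index 3).
pub fn partition_example() -> (Vec<i32>, usize) {
    let mut numbers = vec![3, 8, 1, 5, 2];
    let len = numbers.len();
    let pivot_final_index = partition(&mut numbers, 0, len, 3);
    (numbers, pivot_final_index)
}
```

Stepping through it: the pivot 5 moves to the back, giving [3, 8, 1, 2, 5]. The loop then moves 3, 1 and 2 to the front, giving [3, 1, 2, 8, 5], and the final swap yields [3, 1, 2, 5, 8] with the pivot at index 3.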
So to finish sorting, we'll set up a recursive helper that, again, functions on a sub-segment of the array. We'll choose a random element and partition by it:
pub fn quick_sorter_helper(
numbers: &mut Vec<i32>, start: usize, end: usize) {
if start >= end {
return;
}
let mut rng = thread_rng();
let initial_partition = rng.gen_range(start, end);
let partition_index =
partition(numbers, start, end, initial_partition);
...
}
Now that we've partitioned, all that's left to do is recursively sort each side of the partition! Our main API function will call this helper with the full size of the array.
pub fn quick_sorter_helper(
numbers: &mut Vec<i32>, start: usize, end: usize) {
if start >= end {
return;
}
let mut rng = thread_rng();
let initial_partition = rng.gen_range(start, end);
let partition_index =
partition(numbers, start, end, initial_partition);
quick_sorter_helper(numbers, start, partition_index);
quick_sorter_helper(numbers, partition_index + 1, end);
}
pub fn quick_sorter(numbers: &mut Vec<i32>) {
quick_sorter_helper(numbers, 0, numbers.len());
}
Now that we've got this function, let's add tests and benchmarks for it:
#[test]
fn test_quick_sort() {
let mut numbers: Vec<i32> = random_vector(100);
quick_sorter(&mut numbers);
for i in 0..(numbers.len() - 1) {
assert!(numbers[i] <= numbers[i + 1]);
}
}
#[bench]
fn bench_quick_sort_100_ints(b: &mut Bencher) {
b.iter(|| {
let mut numbers: Vec<i32> = random_vector(100);
quick_sorter(&mut numbers)
});
}
// Same kind of benchmarks for 1000, 10000, 100000
Then we can run our benchmarks and see our results:
running 9 tests
test sorter::test_insertion_sort ... ignored
test sorter::test_quick_sort ... ignored
test sorter::bench_insertion_sort_10000_ints ... bench: 65,130,880 ns/iter (+/- 49,548,187)
test sorter::bench_insertion_sort_1000_ints ... bench: 312,300 ns/iter (+/- 243,337)
test sorter::bench_insertion_sort_100_ints ... bench: 6,159 ns/iter (+/- 4,139)
test sorter::bench_quick_sort_100000_ints ... bench: 14,292,660 ns/iter (+/- 5,815,870)
test sorter::bench_quick_sort_10000_ints ... bench: 1,263,985 ns/iter (+/- 622,788)
test sorter::bench_quick_sort_1000_ints ... bench: 105,443 ns/iter (+/- 65,812)
test sorter::bench_quick_sort_100_ints ... bench: 9,259 ns/iter (+/- 3,882)
Quicksort does much better on the larger values, as expected! We can discern that the times seem to only go up by a factor of a little more than 10 for each tenfold increase in size. From these numbers alone, it's difficult to confirm that the true runtime is actually O(n log n). But we can clearly see that we're much closer to linear time!
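As a rough sanity check, we can compute the ratio an O(n log n) algorithm would predict for a tenfold increase and compare it with the measurements. This is a back-of-the-envelope sketch, and the function name is made up:

```rust
/// Predicted runtime ratio for an O(n log n) algorithm
/// when the input size grows from `small` to `large`.
pub fn n_log_n_ratio(small: f64, large: f64) -> f64 {
    (large * large.ln()) / (small * small.ln())
}
```

Going from 1000 to 10000 elements predicts a ratio of about 13.3, and going from 10000 to 100000 predicts 12.5. The measured ratios are about 12.0 (1,263,985 / 105,443) and 11.3 (14,292,660 / 1,263,985). Given the large variance in these runs, that is in the same ballpark, and far from the factor of 100 a quadratic algorithm would show.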
Conclusion
That's all for this intermediate series on Rust! Next week, we'll summarize the skills we learned over the course of these couple months in Rust. Then we'll look ahead to our next series of topics, including some totally new kinds of content!
Don't forget! If you've never programmed in Rust before, our Rust Video Tutorial provides an in-depth introduction to the basics!
Cleaning our Rust with Monadic Functions
A couple weeks ago we explored how to add authentication to a Rocket Rust server. This involved writing a from_request
function that was very messy. You can see the original version of that function as an appendix at the bottom. But this week, we're going to try to improve that function! We'll explore functions like map
and and_then
in Rust. These can help us write cleaner code using similar ideas to functors and monads in Haskell.
For more details on this code, take a look at our Github Repo! For this article, you should look at rocket_auth_monads.rs
. For a simpler introduction to Rust, take a look at our Rust Beginners Series!
Closures and Mapping
First, let's talk a bit about Rust's equivalent to fmap
and functors. Suppose we have a simple option wrapper and a "doubling" function:
fn double(x: f64) -> f64 {
2.0 * x
}
fn main() -> () {
let x: Option<f64> = Some(5.0);
...
}
We'd like to pass our x value to the double function, but it's wrapped in the Option type. A logical thing to do would be to return None if the input is None, and otherwise apply the function and re-wrap the result in Some. In Haskell, we describe this behavior with the Functor class. Rust's approach has some similarities and some differences.
Rust has no general Functor trait. Instead, individual types provide their own map method. Option defines one directly, and so does Result, while the Iterator trait provides map for sequences of items. As in Haskell, we provide a function that transforms the underlying items. Here's how we can apply our simple example with an Option:
fn main() -> () {
let x: Option<f64> = Some(5.0);
let y: Option<f64> = x.map(double);
}
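To see both cases side by side, here is a small sketch that wraps the call in a function and also shows the `map` adapter on iterators. The names `double_option` and `double_all` are invented for this example:

```rust
fn double(x: f64) -> f64 {
    2.0 * x
}

/// Applies `double` inside the Option: Some(x) becomes Some(2x), None stays None.
pub fn double_option(x: Option<f64>) -> Option<f64> {
    x.map(double)
}

/// The iterator version of `map` transforms every item of a sequence.
pub fn double_all(xs: &[f64]) -> Vec<f64> {
    xs.iter().map(|x| double(*x)).collect()
}
```

Calling `double_option(Some(5.0))` yields `Some(10.0)`, while `double_option(None)` yields `None` without ever calling `double`.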
One notable difference from Haskell is that map is a member function (a method) of the Option type. In Haskell of course, there's no such thing as member functions, so fmap exists on its own.
In Haskell, we can use lambda expressions as arguments to higher order functions. In Rust, it's the same, but they're referred to as closures instead. The syntax is rather different as well. We list the parameters within bars, and then provide the body, here as a brace-delimited code block. Here's a simple example:
fn main() -> () {
let x: Option<f64> = Some(5.0);
let y: Option<f64> = x.map(|x| {2.0 * x});
}
Type annotations are also possible (and sometimes necessary) when specifying the closure. Unlike Haskell, we provide these on the same line as the definition:
fn main() -> () {
let x: Option<f64> = Some(5.0);
let y: Option<f64> = x.map(|x: f64| -> f64 {2.0 * x});
}
And Then…
Now using map
is all well and good, but our authentication example involved using the result of one effectful call in the next effect. As most Haskellers can tell you, this is a job for monads and not merely functors. We can capture some of the same effects of monads with the and_then
function in Rust. This works a lot like the bind operator (>>=)
in Haskell. It also takes an input function. And this function takes a pure input but produces an effectful output.
Here's how we apply it with Option. We start with a safe_square_root function that produces None when its input is negative. Then we can take our original Option and use and_then to apply the square root function.
fn safe_square_root(x: f64) -> Option<f64> {
if x < 0.0 {
None
} else {
Some(x.sqrt())
}
}
fn main() -> () {
let x: Option<f64> = Some(5.0);
x.and_then(safe_square_root);
}
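The benefit becomes clearer when we chain more than one fallible step. In this sketch, `fourth_root` and `nested_with_map` are hypothetical names. The first chains two square roots with `and_then`, and the second shows what happens if we reach for `map` instead:

```rust
fn safe_square_root(x: f64) -> Option<f64> {
    if x < 0.0 {
        None
    } else {
        Some(x.sqrt())
    }
}

/// Chains two fallible steps. Any None short-circuits the rest.
pub fn fourth_root(x: Option<f64>) -> Option<f64> {
    x.and_then(safe_square_root).and_then(safe_square_root)
}

/// Using `map` with a function that returns an Option nests the result.
pub fn nested_with_map(x: Option<f64>) -> Option<Option<f64>> {
    x.map(safe_square_root)
}
```

Here `fourth_root(Some(16.0))` gives `Some(2.0)` and `fourth_root(Some(-1.0))` gives `None`. By contrast, `nested_with_map(Some(4.0))` gives `Some(Some(2.0))`, which is why we want `and_then` for effectful functions, just as we want `>>=` rather than `fmap` in Haskell.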
Converting to Outcomes
Now let's switch gears to our authentication example. Our final result type wasn't Option
. Some intermediate results used this. But in the end, we wanted an Outcome
. So to help us on our way, let's write a simple function to convert our options into outcomes. We'll have to provide the extra information of what the failure result should be. This is the status_error
parameter.
fn option_to_outcome<R>(
result: Option<R>,
status_error: (Status, LoginError))
-> Outcome<R, LoginError> {
match result {
Some(r) => Outcome::Success(r),
None => Outcome::Failure(status_error)
}
}
Now let's start our refactoring process. To begin, let's examine the retrieval of our username and password from the headers. We'll make a separate function for this. This should return an Outcome
, where the success value is a tuple of two strings. We'll start by defining our failure outcome, a tuple of a status and our LoginError
.
fn read_auth_from_headers(headers: &HeaderMap)
-> Outcome<(String, String), LoginError> {
let fail = (Status::BadRequest, LoginError::InvalidData);
...
}
We'll first retrieve the username out of the headers. Recall that this operation returns an Option
. So we can convert it to an Outcome
using our function. We can then use and_then
with a closure taking the unwrapped username.
fn read_auth_from_headers(headers: &HeaderMap)
-> Outcome<(String, String), LoginError> {
let fail = (Status::BadRequest, LoginError::InvalidData);
option_to_outcome(headers.get_one("username"), fail.clone())
.and_then(|u| -> Outcome<(String, String), LoginError> {
...
})
}
We can then do the same thing with the password field. When we've successfully unwrapped both fields, we can return our final Success
outcome.
fn read_auth_from_headers(headers: &HeaderMap)
-> Outcome<(String, String), LoginError> {
let fail = (Status::BadRequest, LoginError::InvalidData);
option_to_outcome(headers.get_one("username"), fail.clone())
.and_then(|u| {
option_to_outcome(
headers.get_one("password"), fail.clone())
.and_then(|p| {
Outcome::Success(
(String::from(u), String::from(p)))
})
})
}
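The same chaining pattern can be tried out without Rocket. The sketch below is a simplified stand-in, not Rocket's API: it uses a plain HashMap in place of the request headers and the standard Result type in place of Outcome. The function name `read_auth_from_map` and the cut-down `LoginError` are assumptions for illustration:

```rust
use std::collections::HashMap;

/// Cut-down error type for this sketch.
#[derive(Debug, Clone, PartialEq)]
pub enum LoginError {
    InvalidData,
}

/// Reads both credentials from a map of header names to values.
/// `ok_or` plays the role of `option_to_outcome`: it turns a missing
/// value into the failure case.
pub fn read_auth_from_map(
    headers: &HashMap<&str, &str>,
) -> Result<(String, String), LoginError> {
    headers
        .get("username")
        .ok_or(LoginError::InvalidData)
        .and_then(|u| {
            headers
                .get("password")
                .ok_or(LoginError::InvalidData)
                .map(|p| (u.to_string(), p.to_string()))
        })
}
```

If either key is missing, the chain stops at that point and returns the error. Only when both lookups succeed does the final `map` build the tuple.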
Re-Organizing
Armed with this function we can start re-tooling our from_request
function. We'll start by gathering the header results and invoking and_then
. This unwraps the username and password:
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
type Error = LoginError;
fn from_request(request: &'a Request<'r>)
-> Outcome<AuthenticatedUser, LoginError> {
let headers_result =
read_auth_from_headers(&request.headers());
headers_result.and_then(|(u, p)| {
...
})
...
}
}
Now for the next step, we'll make a couple database calls. Both of our normal functions return Option
values. So for each, we'll create a failure Outcome
and invoke option_to_outcome
. We'll follow this up with a call to and_then
. First we get the user based on the username. Then we find their AuthInfo
using the ID.
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
type Error = LoginError;
fn from_request(request: &'a Request<'r>)
-> Outcome<AuthenticatedUser, LoginError> {
let headers_result =
read_auth_from_headers(&request.headers());
headers_result.and_then(|(u, p)| {
let conn_str = local_conn_string();
let maybe_user =
fetch_user_by_email(&conn_str, &String::from(u));
let fail1 =
(Status::NotFound, LoginError::UsernameDoesNotExist);
option_to_outcome(maybe_user, fail1)
.and_then(|user: UserEntity| {
let fail2 = (Status::MovedPermanently,
LoginError::WrongPassword);
option_to_outcome(
fetch_auth_info_by_user_id(
&conn_str, user.id), fail2)
})
.and_then(|auth_info: AuthInfoEntity| {
...
})
})
}
}
This gives us unwrapped authentication info. We can use this to compare the hash of the original password and return our final Outcome
!
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
type Error = LoginError;
fn from_request(request: &'a Request<'r>)
-> Outcome<AuthenticatedUser, LoginError> {
let headers_result =
read_auth_from_headers(&request.headers());
headers_result.and_then(|(u, p)| {
let conn_str = local_conn_string();
let maybe_user =
fetch_user_by_email(&conn_str, &String::from(u));
let fail1 =
(Status::NotFound, LoginError::UsernameDoesNotExist);
option_to_outcome(maybe_user, fail1)
.and_then(|user: UserEntity| {
let fail2 = (Status::MovedPermanently,
LoginError::WrongPassword);
option_to_outcome(
fetch_auth_info_by_user_id(
&conn_str, user.id), fail2)
})
.and_then(|auth_info: AuthInfoEntity| {
let hash = hash_password(&String::from(p));
if hash == auth_info.password_hash {
Outcome::Success( AuthenticatedUser{
user_id: auth_info.user_id})
} else {
Outcome::Failure(
(Status::Forbidden,
LoginError::WrongPassword))
}
})
})
}
}
Conclusion
Is this new solution that much better than our original? Well, it avoids the "triangle of death" pattern in our code. But it's not necessarily that much shorter. Perhaps it's a little cleaner on the whole though. Ultimately these code choices are up to you! Next time, we'll wrap up our current exploration of Rust by seeing how to profile our code in Rust.
This series has covered some more advanced topics in Rust. For a more in-depth introduction, check out our Rust Video Tutorial!
Appendix: Original Function
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
type Error = LoginError;
fn from_request(request: &'a Request<'r>) -> Outcome<AuthenticatedUser, LoginError> {
let username = request.headers().get_one("username");
let password = request.headers().get_one("password");
match (username, password) {
(Some(u), Some(p)) => {
let conn_str = local_conn_string();
let maybe_user = fetch_user_by_email(&conn_str, &String::from(u));
match maybe_user {
Some(user) => {
let maybe_auth_info = fetch_auth_info_by_user_id(&conn_str, user.id);
match maybe_auth_info {
Some(auth_info) => {
let hash = hash_password(&String::from(p));
if hash == auth_info.password_hash {
Outcome::Success(AuthenticatedUser{user_id: user.id})
} else {
Outcome::Failure((Status::Forbidden, LoginError::WrongPassword))
}
}
None => {
Outcome::Failure((Status::MovedPermanently, LoginError::WrongPassword))
}
}
}
None => Outcome::Failure((Status::NotFound, LoginError::UsernameDoesNotExist))
}
},
_ => Outcome::Failure((Status::BadRequest, LoginError::InvalidData))
}
}
}
Rocket Frontend: Templates and Static Assets
In the last few articles, we've been exploring the Rocket library for Rust web servers. Last time out, we tried a couple ways to add authentication to our web server. In this last Rocket-specific post, we'll explore some ideas around frontend templating. This will make it easy for you to serve HTML content to your users!
To explore the code for this article, head over to the "rocket_template" file on our Github repo! If you're still new to Rust, you might want to start with some simpler material. Take a look at our Rust Beginners Series as well!
Templating Basics
First, let's understand the basics of HTML templating. When our server serves out a webpage, we return HTML to the user for the browser to render. Consider this simple index page:
<html>
<head></head>
<body>
<p> Welcome to the site!</p>
</body>
</html>
But of course, each user should see some kind of custom content. For example, in our greeting, we might want to give the user's name. In an HTML template, we'll create a variable of sorts in our HTML, delineated by braces:
<html>
<head></head>
<body>
<p> Welcome to the site {{name}}!</p>
</body>
</html>
Now before we return the HTML to the user, we want to perform a substitution. Where we find the variable {{name}}, we should replace it with the user's name, which our server should know.
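To illustrate the idea of substitution only, here is a toy renderer that replaces each `{{key}}` with its value using plain string replacement. This is not how Handlebars or Tera work internally. Real template engines parse the template and escape HTML in the substituted values, among other things. The function name `render_toy` is made up:

```rust
use std::collections::HashMap;

/// Replaces every `{{key}}` in the template with the matching context value.
/// Toy illustration only: no parsing, no HTML escaping.
pub fn render_toy(template: &str, context: &HashMap<&str, &str>) -> String {
    let mut output = template.to_string();
    for (key, value) in context.iter() {
        // Doubled braces are escapes in format!, so this builds "{{key}}".
        let placeholder = format!("{{{{{}}}}}", key);
        output = output.replace(placeholder.as_str(), *value);
    }
    output
}
```

With a context mapping "name" to "Jonathan", the line `<p> Welcome to the site {{name}}!</p>` becomes `<p> Welcome to the site Jonathan!</p>`.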
There are many different libraries that do this, often through Javascript. But in Rust, it turns out the Rocket library has a couple easy templating integrations. One option is Tera, which was specifically designed for Rust. Another option is Handlebars, which is more native to Javascript, but also has a Rocket integration. The substitutions in this article are simple, so there's not actually much of a difference for us.
Returning a Template
So how do we configure our server to return this HTML data? To start, we have to attach a "Fairing" to our server, specifically for the Template
library. A Fairing is a server-wide piece of middleware. This is how we can allow our endpoints to return templates:
use rocket_contrib::templates::Template;
fn main() {
rocket::ignite()
.mount("/", routes![index, get_user])
.attach(Template::fairing())
.launch();
}
Now we can make our index
endpoint. It has no inputs, and it will return Rocket's Template
type.
#[get("/")]
fn index() -> Template {
...
}
We have two tasks now. First, we have to construct our context. This can be any "map-like" type with string information. We'll use a HashMap
, populating the name
value.
#[get("/")]
fn index() -> Template {
let context: HashMap<&str, &str> = [("name", "Jonathan")]
.iter().cloned().collect();
...
}
Now we have to render
our template. Let's suppose we have a "templates" directory at the root of our project. We can put the template we wrote above in the "index.hbs" file. When we call the render
function, we just give the name of our template and pass the context
!
#[get("/")]
fn index() -> Template {
let context: HashMap<&str, &str> = [("name", "Jonathan")]
.iter().cloned().collect();
Template::render("index", &context)
}
Including Static Assets
Rocket also makes it quite easy to include static assets as part of our routing system. We just have to mount
the static
route to the desired prefix when launching our server:
fn main() {
rocket::ignite()
.mount("/static", StaticFiles::from("static"))
.mount("/", routes![index, get_user])
.attach(Template::fairing())
.launch();
}
Now any request to a /static/...
endpoint will return the corresponding file in the "static" directory of our project. Suppose we have this styles.css
file:
p {
color: red;
}
We can then link to this file in our index template:
<html>
<head>
<link rel="stylesheet" type="text/css" href="static/styles.css"/>
</head>
<body>
<p> Welcome to the site {{name}}!</p>
</body>
</html>
Now when we fetch our index, we'll see that the text on the page is red!
Looping in our Database
Now for one last piece of integration with our database. Let's make a page that will show a user their basic information. This starts with a simple template:
<!-- templates/user.hbs -->
<html>
<head></head>
<body>
<p> User name: {{name}}</p>
<br>
<p> User email: {{email}}</p>
<br>
<p> User age: {{age}}</p>
</body>
</html>
We'll compose an endpoint that takes the user's ID as an input and fetches the user from the database:
#[get("/users/<uid>")]
fn get_user(uid: i32) -> Template {
let maybe_user = fetch_user_by_id(&local_conn_string(), uid);
...
}
Now we need to build our context from the user information. This will require a match
statement on the resulting user. We'll use Unknown
for the fields if the user doesn't exist.
#[get("/users/<uid>")]
fn get_user(uid: i32) -> Template {
let maybe_user = fetch_user_by_id(&local_conn_string(), uid);
let context: HashMap<&str, String> = {
match maybe_user {
Some(u) =>
[ ("name", u.name.clone())
, ("email", u.email.clone())
, ("age", u.age.to_string())
].iter().cloned().collect(),
None =>
[ ("name", String::from("Unknown"))
, ("email", String::from("Unknown"))
, ("age", String::from("Unknown"))
].iter().cloned().collect()
}
};
Template::render("user", &context)
}
And to wrap it up, we'll render
the "user" template! Now when users get directed to the page for their user ID, they'll see their information!
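The two match branches above repeat the same three keys. One possible alternative is to build each entry with `map` and a fallback value. The sketch below is independent of Rocket and Diesel: the `User` struct is a hypothetical stand-in for `UserEntity`, and `user_context` is an invented name:

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the `UserEntity` type.
pub struct User {
    pub name: String,
    pub email: String,
    pub age: i32,
}

/// Fallback value used when there is no user.
fn unknown() -> String {
    String::from("Unknown")
}

/// Builds the template context, falling back to "Unknown" for each field.
pub fn user_context(maybe_user: Option<&User>) -> HashMap<&'static str, String> {
    let mut context = HashMap::new();
    context.insert("name", maybe_user.map(|u| u.name.clone()).unwrap_or_else(unknown));
    context.insert("email", maybe_user.map(|u| u.email.clone()).unwrap_or_else(unknown));
    context.insert("age", maybe_user.map(|u| u.age.to_string()).unwrap_or_else(unknown));
    context
}
```

Whether this reads better than the explicit match is a matter of taste. It keeps each key on a single line, at the cost of checking the Option three times.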
Conclusion
Next week, we'll go back to some of our authentication code. But we'll do so with the goal of exploring a more universal Rust idea. We'll see how functors and monads still find a home in Rust. We'll explore the functions that allow us to clean up heavy conditional code just as we could in Haskell.
For a more in-depth introduction to Rust basics, be sure to take a look at our Rust Video Tutorial!
Authentication in Rocket
Last week we enhanced our Rocket web server. We combined our server with our Diesel schema to enable a series of basic CRUD endpoints. This week, we'll continue this integration, but bring in some more cool Rocket features. We'll explore two different methods of authentication. First, we'll create a "Request Guard" to allow a form of Basic Authentication. Then we'll also explore Rocket's amazingly simple Cookies integration.
As always, you can explore the code for this series by heading to our Github repository. For this article specifically, you'll want to take a look at the rocket_auth.rs file.
If you're just starting your Rust journey, feel free to check out our Beginners Series as well!
New Data Types
To start off, let's make a few new types to help us. First, we'll need a new database table, auth_infos
, based on this struct:
#[derive(Insertable)]
pub struct AuthInfo {
pub user_id: i32,
pub password_hash: String
}
When the user creates their account, they'll provide a password. We'll store a hash of that password in our database table. Of course, you'll want to run through all the normal steps we did with Diesel to create this table. This includes having the corresponding Entity
type.
We'll also want a couple new form types to accept authentication information. First off, when we create a user, we'll now include the password in the form.
#[derive(FromForm, Deserialize)]
struct CreateInfo {
name: String,
email: String,
age: i32,
password: String
}
Second, when a user wants to login, they'll pass their username (email) and their password.
#[derive(FromForm, Deserialize)]
struct LoginInfo {
username: String,
password: String,
}
Both these types should derive FromForm
and Deserialize
so we can grab them out of "post" data. You might wonder, do we need another type to store the same information that already exists in User
and UserEntity
? It would be possible to write CreateInfo
to have a User
within it. But then we'd have to manually write the FromForm
instance. This isn't difficult, but it might be more tedious than using a new type.
Creating a User
So in the first place, we have to create our user so they're matched up with their password. This requires taking the CreateInfo
in our post request. We'll first unwrap the user fields and insert our User
object. This follows the patterns we've seen so far in this series with Diesel.
#[post("/users/create", format="json", data="<create_info>")]
fn create(db: State<String>, create_info: Json<CreateInfo>)
-> Json<i32> {
let user: User = User
{ name: create_info.name.clone(),
email: create_info.email.clone(),
age: create_info.age};
let connection = ...;
let user_entity: UserEntity = diesel::insert_into(users::table)...
…
}
Now we'll want a function for hashing our password. We'll use the SHA3 algorithm, courtesy of the rust-crypto
library:
fn hash_password(password: &String) -> String {
let mut hasher = Sha3::sha3_256();
hasher.input_str(password);
hasher.result_str()
}
We'll apply this function on the input password and attach it to the created user ID. Then we can insert the new AuthInfo
and return the created ID.
#[post("/users/create", format="json", data="<create_info>")]
fn create(db: State<String>, create_info: Json<CreateInfo>)
-> Json<i32> {
...
let user_entity: UserEntity = diesel::insert_into(users::table)...
let password_hash = hash_password(&create_info.password);
let auth_info: AuthInfo = AuthInfo
{user_id: user_entity.id, password_hash: password_hash};
let auth_info_entity: AuthInfoEntity =
diesel::insert_into(auth_infos::table)...
Json(user_entity.id)
}
Now whenever we create our user, they'll have their password attached!
Gating an Endpoint
Now that our user has a password, how do we gate endpoints on authentication? Well the first approach we can try is something like "Basic Authentication". This means that every authenticated request contains the username and the password. In our example we'll get these directly out of header elements. But in a real application you would want to double check that the request is encrypted before doing this.
But it would be tiresome to apply the logic of reading the headers in every handler. So Rocket has a powerful functionality called "Request Guards". Rocket has a special trait called FromRequest
. Whenever a particular type is an input to a handler function, it runs the from_request
function. This determines how to derive the value from the request. In our case, we'll make a wrapper type AuthenticatedUser
. This represents a user that has included their auth info in the request.
struct AuthenticatedUser {
user_id: i32
}
Now we can include this type in a handler signature. For this endpoint, we only allow a user to retrieve their data if they've logged in:
#[get("/users/my_data")]
fn login(db: State<String>, user: AuthenticatedUser)
-> Json<Option<UserEntity>> {
Json(fetch_user_by_id(&db, user.user_id))
}
Implementing the Request Trait
The trick of course is that we need to implement the FromRequest
trait! This is more complicated than it sounds! Our handler will have the ability to short-circuit the request and return an error. So let's start by specifying a couple potential login errors we can throw.
#[derive(Debug)]
enum LoginError {
InvalidData,
UsernameDoesNotExist,
WrongPassword
}
The from_request
function will take in a request and return an Outcome
. The outcome will either provide our authentication type or an error. The last bit of adornment we need on this is lifetime specifiers for the request itself and the reference to it.
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
type Error = LoginError;
fn from_request(request: &'a Request<'r>)
-> Outcome<AuthenticatedUser, LoginError> {
...
}
}
Now the actual function definition involves several layers of case matching! It consists of a few different operations that have to query the request or query our database. For example, let's consider the first layer. We insist on having two headers in our request: one for the username, and one for the password. We'll use request.headers()
to check for these values. If either doesn't exist, we'll send a Failure
outcome with invalid data. Here's what that looks like:
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
type Error = LoginError;
fn from_request(request: &'a Request<'r>)
-> Outcome<AuthenticatedUser, LoginError> {
let username = request.headers().get_one("username");
let password = request.headers().get_one("password");
match (username, password) {
(Some(u), Some(p)) => {
...
}
_ => Outcome::Failure(
(Status::BadRequest,
LoginError::InvalidData))
}
}
}
In the main branch of the function, we'll do 3 steps:
- Find the user in our database based on their email address/username.
- Find their authentication information based on the ID
- Hash the input password and compare it to the database hash
If we are successful, then we'll return a successful outcome:
Outcome::Success(AuthenticatedUser { user_id: user.id })
The number of match
levels required makes the function definition very verbose. So we've included it at the bottom as an appendix. We know how to take such a function and write it more cleanly in Haskell using monads. In a couple weeks, we'll use this function as a case study to explore Rust's monadic abilities.
Logging In with Cookies
In most applications though, we won't want to include the password in the request each time. In HTTP, "Cookies" provide a way to store information about a particular user that we can track on our server.
Rocket makes this very easy with the Cookies
type! We can always include this mutable type in our requests. It works like a key-value store, where we can access certain information with a key like "user_id"
. Since we're storing auth information, we'll also want to make sure it's encrypted, or "private". So we'll use these functions:
add_private(...)
get_private(...)
remove_private(...)
Let's start with a "login" endpoint. This will take our LoginInfo
object as its post data, but we'll also have the Cookies
input:
#[post("/users/login", format="json", data="<login_info>")]
fn login_post(db: State<String>, login_info: Json<LoginInfo>, mut cookies: Cookies) -> Json<Option<i32>> {
...
}
First we have to make sure a user of that name exists in the database:
#[post("/users/login", format="json", data="<login_info>")]
fn login_post(
db: State<String>,
login_info: Json<LoginInfo>,
mut cookies: Cookies)
-> Json<Option<i32>> {
let maybe_user = fetch_user_by_email(&db, &login_info.username);
match maybe_user {
Some(user) => {
...
}
None => Json(None)
}
}
Then we have to get their auth info again. We'll hash the password and compare it. If we're successful, then we'll add the user's ID as a cookie. If not, we'll return None
.
#[post("/users/login", format="json", data="<login_info>")]
fn login_post(
db: State<String>,
login_info: Json<LoginInfo>,
mut cookies: Cookies)
-> Json<Option<i32>> {
let maybe_user = fetch_user_by_email(&db, &login_info.username);
match maybe_user {
Some(user) => {
let maybe_auth = fetch_auth_info_by_user_id(&db, user.id);
match maybe_auth {
Some(auth_info) => {
let hash = hash_password(&login_info.password);
if hash == auth_info.password_hash {
cookies.add_private(Cookie::new(
"user_id", user.id.to_string()));
Json(Some(user.id))
} else {
Json(None)
}
}
None => Json(None)
}
}
None => Json(None)
}
}
A more robust solution of course would loop in some error behavior instead of returning None
.
Using Cookies
Using our cookie now is pretty easy. Let's make a separate "fetch user" endpoint using our cookies. It will take the Cookies
object and the user ID as inputs. The first order of business is to retrieve the user_id
cookie and verify it exists.
#[get("/users/cookies/<uid>")]
fn fetch_special(db: State<String>, uid: i32, mut cookies: Cookies)
-> Json<Option<UserEntity>> {
let logged_in_user = cookies.get_private("user_id");
match logged_in_user {
Some(c) => {
...
},
None => Json(None)
}
}
Now we need to parse the string value as a user ID and compare it to the value from the endpoint. If they're a match, we just fetch the user's information from our database!
#[get("/users/cookies/<uid>")]
fn fetch_special(db: State<String>, uid: i32, mut cookies: Cookies)
-> Json<Option<UserEntity>> {
let logged_in_user = cookies.get_private("user_id");
match logged_in_user {
Some(c) => {
let logged_in_uid = c.value().parse::<i32>().unwrap();
if logged_in_uid == uid {
Json(fetch_user_by_id(&db, uid))
} else {
Json(None)
}
},
None => Json(None)
}
}
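One caveat: calling unwrap() on the parsed value will panic if the cookie ever holds something that isn't a number. Here's a minimal sketch of a more defensive comparison using only the standard library. The helper name cookie_matches_uid is made up for illustration.

```rust
// Hypothetical helper: returns true only if the cookie value parses as an
// i32 and equals the user ID from the URL. Malformed values yield false
// instead of a panic.
fn cookie_matches_uid(cookie_value: &str, uid: i32) -> bool {
    match cookie_value.parse::<i32>() {
        Ok(logged_in_uid) => logged_in_uid == uid,
        Err(_) => false,
    }
}
```

Inside the endpoint, we could then write the condition as cookie_matches_uid(c.value(), uid) and keep the rest of the function the same.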
And when we're done, we can also post a "logout" request that will remove the cookie!
#[post("/users/logout", format="json")]
fn logout(mut cookies: Cookies) -> () {
cookies.remove_private(Cookie::named("user_id"));
}
Conclusion
We've got one more article on Rocket before checking out some different Rust concepts. So far, we've only dealt with the backend part of our API. Next week, we'll investigate how we can use Rocket to send templated HTML files and other static web content!
Maybe you're more experienced with Haskell but still need a bit of an introduction to Rust. We've got some other materials for you! Watch our Rust Video Tutorial for an in-depth look at the basics of the language!
Appendix: From Request Function
impl<'a, 'r> FromRequest<'a, 'r> for AuthenticatedUser {
type Error = LoginError;
fn from_request(request: &'a Request<'r>) -> Outcome<AuthenticatedUser, LoginError> {
let username = request.headers().get_one("username");
let password = request.headers().get_one("password");
match (username, password) {
(Some(u), Some(p)) => {
let conn_str = local_conn_string();
let maybe_user = fetch_user_by_email(&conn_str, &String::from(u));
match maybe_user {
Some(user) => {
let maybe_auth_info = fetch_auth_info_by_user_id(&conn_str, user.id);
match maybe_auth_info {
Some(auth_info) => {
let hash = hash_password(&String::from(p));
if hash == auth_info.password_hash {
Outcome::Success(AuthenticatedUser{user_id: user.id})
} else {
Outcome::Failure((Status::Forbidden, LoginError::WrongPassword))
}
}
None => {
Outcome::Failure((Status::MovedPermanently, LoginError::WrongPassword))
}
}
}
None => Outcome::Failure((Status::NotFound, LoginError::UsernameDoesNotExist))
}
},
_ => Outcome::Failure((Status::BadRequest, LoginError::InvalidData))
}
}
}
Joining Forces: An Integrated Rust Web Server
We've now explored a couple different libraries for some production tasks in Rust. A couple weeks ago, we used Diesel to create an ORM for some database types. And then last week, we used Rocket to make a basic web server to respond to basic requests. This week, we'll put these two ideas together! We'll use some more advanced functionality from Rocket to make some CRUD endpoints for our database type. Take a look at the code on Github here!
If you've never written any Rust, you should start with the basics though! Take a look at our Rust Beginners Series!
Database State and Instances
Our first order of business is connecting to the database from our handler functions. There are some direct integrations you can check out between Rocket, Diesel, and other libraries. These can provide clever ways to add a connection argument to any handler.
But for now we're going to keep things simple. We'll re-generate the PgConnection
within each endpoint. We'll maintain a "stateful" connection string to ensure they all use the same database.
Our Rocket server can "manage" different state elements. Suppose we have a function that gives us our database string. We can pass that to our server at initialization time.
fn local_conn_string() -> String {...}
fn main() {
rocket::ignite()
.mount("/", routes![...])
.manage(local_conn_string())
.launch();
}
Now we can access this String
from any of our endpoints by giving an input the State<String>
type. This allows us to create our connection:
#[get(...)]
fn fetch_all_users(database_url: State<String>) -> ... {
let connection = PgConnection::establish(&database_url)
.expect("Error connecting to database!");
...
}
Note: We can't use the PgConnection
itself because stateful types need to be thread safe.
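To see what "thread safe" means in Rust terms, here's a small standard-library sketch. Thread safety is expressed through the Send and Sync marker traits, and a String satisfies both. The function names below are ours, and read_from_threads is only a stand-in for several handlers reading the same connection string at once.

```rust
use std::sync::Arc;
use std::thread;

// This only compiles for types that can be sent to and shared between threads.
fn assert_thread_safe<T: Send + Sync>() {}

fn connection_string_is_thread_safe() {
    assert_thread_safe::<String>();
    // The next line would NOT compile, because Rc is neither Send nor Sync:
    // assert_thread_safe::<std::rc::Rc<String>>();
}

// Hypothetical stand-in for several handlers reading the shared string.
// Each worker thread gets its own Arc handle and returns the string length.
fn read_from_threads(conn_string: String, workers: usize) -> Vec<usize> {
    let shared = Arc::new(conn_string);
    let handles: Vec<thread::JoinHandle<usize>> = (0..workers)
        .map(|_| {
            let local = Arc::clone(&shared);
            thread::spawn(move || local.len())
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```

A type that can't be shared safely gets rejected by the compiler in the same way, which is why we store the plain connection string instead.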
So any of our other endpoints can now access the same database. Before we start writing these, though, we need a couple things. Let's recall that for our Diesel ORM we made a User
type and a UserEntity
type. The first is for inserting/creating, and the second is for querying. We need to add some instances to those types so they are compatible with our endpoints. We want to have JSON instances (Serialize, Deserialize), as well as FromForm
for our User
type:
#[derive(Insertable, Deserialize, Serialize, FromForm)]
#[table_name="users"]
pub struct User {
...
}
#[derive(Queryable, Serialize)]
pub struct UserEntity {
...
}
Now let's see how we get these types from our endpoints!
Retrieving Users
We'll start with a simple endpoint to fetch all the different users in our database. This will take no inputs, except our stateful database URL. It will return a vector of UserEntity
objects, wrapped in Json
.
#[get("/users/all")]
fn fetch_all_users(database_url: State<String>)
-> Json<Vec<UserEntity>> {
...
}
Now all we need to do is connect to our database and run the query function. We can make our users vector into a Json
object by wrapping with Json()
. The Serialize
instance lets us satisfy the Responder
trait for the return value.
#[get("/users/all")]
fn fetch_all_users(database_url: State<String>)
-> Json<Vec<UserEntity>> {
let connection = PgConnection::establish(&database_url)
.expect("Error connecting to database!");
use rust_web::schema::users::dsl::*;
Json(users.load::<UserEntity>(&connection)
.expect("Error loading users"))
}
Now for getting individual users. Once again, we'll wrap a response in JSON. But this time we'll return a single, optional user. We'll use a dynamic capture parameter in the URL for the User ID.
#[get("/users/<uid>")]
fn fetch_user(database_url: State<String>, uid: i32)
-> Option<Json<UserEntity>> {
let connection = ...;
...
}
We'll want to filter on the users table by the ID. This will give us a list of different results. We want to specify this vector as mutable. Why? In the end, we want to return the first user. But Rust's memory rules mean we must either copy or move this item. And we don't want to move a single item from the vector without moving the whole vector. So we'll remove the head from the vector entirely, which requires mutability.
#[get("/users/<uid>")]
fn fetch_user(database_url: State<String>, uid: i32)
-> Option<Json<UserEntity>> {
let connection = ...;
use rust_web::schema::users::dsl::*;
let mut users_by_id: Vec<UserEntity> =
users.filter(id.eq(uid))
.load::<UserEntity>(&connection)
.expect("Error loading users");
...
}
Now we can do our case analysis. If the list is empty, we return None
. Otherwise, we'll remove the user from the vector and wrap it.
#[get("/users/<uid>")]
fn fetch_user(database_url: State<String>, uid: i32) -> Option<Json<UserEntity>> {
let connection = ...;
use rust_web::schema::users::dsl::*;
let mut users_by_id: Vec<UserEntity> =
users.filter(id.eq(uid))
.load::<UserEntity>(&connection)
.expect("Error loading users");
if users_by_id.is_empty() {
None
} else {
let first_user = users_by_id.remove(0);
Some(Json(first_user))
}
}
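If we don't need the rest of the vector afterwards, another option is to consume it and move the first element out through an iterator. This is only a sketch of the ownership idea with a made-up helper name. It isn't part of the original endpoint.

```rust
// Takes ownership of the whole vector and moves the first element out.
// Returns None for an empty vector, so no separate length check is needed.
fn first_or_none<T>(items: Vec<T>) -> Option<T> {
    items.into_iter().next()
}
```

With this approach the vector no longer has to be declared mutable, since we give it away instead of modifying it in place.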
Create/Update/Delete
Hopefully you can see the pattern now! Our queries are all pretty simple. So our endpoints all follow a similar pattern. Connect to the database, run the query and wrap the result. We can follow this process for the remaining three endpoints in a basic CRUD setup. Let's start with "Create":
#[post("/users/create", format="application/json", data = "<user>")]
fn create_user(database_url: State<String>, user: Json<User>)
-> Json<i32> {
let connection = ...;
let user_entity: UserEntity = diesel::insert_into(users::table)
.values(&*user)
.get_result(&connection).expect("Error saving user");
Json(user_entity.id)
}
As we discussed last week, we can use data
together with Json
to specify the form data in our post request. We de-reference the user with *
to get it out of the JSON wrapper. Then we insert the user and wrap its ID to send back.
Deleting a user is simple as well. It has the same dynamic path as fetching a user. We just make a delete
call on our database instead.
#[delete("/users/<uid>")]
fn delete_user(database_url: State<String>, uid: i32) -> Json<i32> {
let connection = ...;
use rust_web::schema::users::dsl::*;
diesel::delete(users.filter(id.eq(uid)))
.execute(&connection)
.expect("Error deleting user");
Json(uid)
}
Updating is the last endpoint, which takes a put
request. The endpoint mechanics are just like our other endpoints. We use a dynamic path component to get the user's ID, and then provide a User
body with the updated field values. The only trick is that we need to expand our Diesel knowledge a bit. We'll use update
and set
to change individual fields on an item.
#[put("/users/<uid>/update", format="json", data="<user>")]
fn update_user(
database_url: State<String>, uid: i32, user: Json<User>)
-> Json<UserEntity> {
let connection = ...;
use rust_web::schema::users::dsl::*;
let updated_user: UserEntity =
diesel::update(users.filter(id.eq(uid)))
.set((name.eq(&user.name),
email.eq(&user.email),
age.eq(user.age)))
.get_result::<UserEntity>(&connection)
.expect("Error updating user");
Json(updated_user)
}
The other gotcha is that we need to use references (&
) for the string fields in the input user. But now we can add these routes to our server, and it will manipulate our database as desired!
Conclusion
There are still lots of things we could improve here. For example, we're still using .expect
in many places. From the perspective of a web server, we should be catching these issues and wrapping them with "Err 500". Rocket also provides some good mechanics for fixing that. Next week though, we'll pivot to another server problem that Rocket solves adeptly: authentication. We should restrict certain endpoints to particular users. Rocket provides an authentication scheme that is neatly encoded in the type system!
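To make the "Err 500" idea a bit more concrete, here's a library-free sketch. Instead of panicking, we turn a Result into a status code and a body. The SimpleResponse type and the respond function are hypothetical. They aren't Rocket's API, just an illustration of the shape such handling could take.

```rust
use std::fmt::Display;

// Hypothetical response type: an HTTP status code plus a text body.
#[derive(Debug, PartialEq)]
struct SimpleResponse {
    status: u16,
    body: String,
}

// Maps success to 200 and any error to 500, rather than calling .expect().
fn respond<T: Display, E: Display>(result: Result<T, E>) -> SimpleResponse {
    match result {
        Ok(value) => SimpleResponse { status: 200, body: value.to_string() },
        Err(error) => SimpleResponse {
            status: 500,
            body: format!("Internal error: {}", error),
        },
    }
}
```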
For a more in-depth introduction to Rust, watch our Rust Video Tutorial. It will take you through a lot of key skills like understanding memory and using Cargo!
Rocket: Web Servers in Rust!
Welcome back to our series on building simple apps in Rust. Last week, we explored the Diesel library which gave us an ORM for database interaction. For the next few weeks, we'll be trying out the Rocket library, which makes it quick and easy to build a web server in Rust! This is comparable to the Servant library in Haskell, which we've explored before.
This week, we'll be working on the basic building blocks of using this library. The reference code for this article is available here on Github!
Rust combines some of the neat functional ideas of Haskell with some more recognizable syntax from C++. To learn more of the basics, take a look at our Rust Beginners Series!
Our First Route
To begin, let's make a simple "hello world" endpoint for our server. We don't specify a full API definition all at once like we do with Servant. But we do use a special macro before the endpoint function. This macro describes the route's method and its path.
#[get("/hello")]
fn index() -> String {
String::from("Hello, world!")
}
So our macro tells us this is a "GET" endpoint and that the path is /hello
. Then our function specifies a String
as the return value. We can, of course, have different types of return values, which we'll explore more as the series goes on.
Launching Our Server
Now this endpoint is useless until we can actually run and launch our server. To do this, we start by creating an object of type Rocket
with the ignite()
function.
fn main() {
let server: Rocket = rocket::ignite();
...
}
We can then modify our server by "mounting" the routes we want. The mount
function takes a base URL path and a list of routes, as generated by the routes
macro. This function returns us a modified server:
fn main() {
let server: Rocket = rocket::ignite();
let server2: Rocket = server.mount("/", routes![index]);
}
Rather than create multiple server objects, we'll just compose these different functions. Then to launch our server, we use launch
on the final object!
fn main() {
rocket::ignite().mount("/", routes![index]).launch();
}
And now our server will respond when we ping it at localhost:8000/hello
! We could, of course, use a different base path. We could even assign different routes to different bases!
fn main() {
rocket::ignite().mount("/api", routes![index]).launch();
}
Now it will respond at /api/hello
.
Query Parameters
Naturally, most endpoints need inputs to be useful. There are a few different ways we can do this. The first is to use path components. In Servant, we call these CaptureParams
. With Rocket, we'll format our URL to have brackets around the variables we want to capture. Then we can assign each one a basic type in our endpoint function:
#[get("/math/<name>")]
fn hello(name: &RawStr) -> String {
format!("Hello, {}!", name.as_str())
}
We can use any type that satisfies the FromParam
trait, including a RawStr
. This is a Rocket-specific type wrapping string-like data in a raw format. With these strings, we might want to apply some sanitization processes on our data. We can also use basic numeric types, like i32
.
#[get("/math/<first>/<second>")]
fn add(first: i32, second: i32) -> String {
format!("{}", first + second)
}
This endpoint will now return "11" when we ping /math/5/6
.
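To get a feel for what the capture conversion involves, here's a hand-rolled sketch with no Rocket at all. It splits the path into segments and parses two of them as i32 values. The function name add_from_path is made up, and Rocket's real routing is of course more general than this.

```rust
// Hypothetical mini-router: matches "/math/<first>/<second>" and adds the
// two captured numbers. Returns None if the path has the wrong shape or a
// segment is not a valid i32.
fn add_from_path(path: &str) -> Option<String> {
    let mut segments = path.trim_start_matches('/').split('/');
    if segments.next()? != "math" {
        return None;
    }
    let first: i32 = segments.next()?.parse().ok()?;
    let second: i32 = segments.next()?.parse().ok()?;
    // Reject paths with extra trailing segments.
    if segments.next().is_some() {
        return None;
    }
    // checked_add avoids a panic on overflow in debug builds.
    Some(first.checked_add(second)?.to_string())
}
```

A path like /math/5/six simply fails to match here, which is the kind of check the typed parameters give us for free.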
We can also use "query parameters", which all go at the end of the URL. These need the FromFormValue
trait, rather than FromParam
. But once again, RawStr
and basic numbers work fine.
#[get("/math?<first>&<second>")]
fn multiply(first: i32, second: i32) -> String {
format!("{}", first * second)
}
Now we'll get "30" when we ping /math?first=5&second=6
.
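Query parameters are matched by name, so each value in the URL needs its key. As a rough illustration, here's a standard-library sketch that parses such a query string by hand. The helper names are ours, and it skips details like percent-decoding that a real framework handles.

```rust
use std::collections::HashMap;

// Splits "first=5&second=6" into key/value pairs (no percent-decoding).
// Pairs without an '=' are ignored.
fn parse_query(query: &str) -> HashMap<&str, &str> {
    query
        .split('&')
        .filter_map(|pair| pair.split_once('='))
        .collect()
}

// Hypothetical handler for "/math?first=..&second=..".
fn multiply_from_url(url: &str) -> Option<String> {
    let (path, query) = url.split_once('?')?;
    if path != "/math" {
        return None;
    }
    let params = parse_query(query);
    let first: i32 = params.get("first")?.parse().ok()?;
    let second: i32 = params.get("second")?.parse().ok()?;
    Some(first.checked_mul(second)?.to_string())
}
```

Because the lookup is by key, the order of the parameters in the URL doesn't matter.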
Post Requests
The last major input type we'll deal with is post request data. Suppose we have a basic user type:
struct User {
name: String,
email: String,
age: i32
}
We'll want to derive various traits for it so we can use it within endpoints. From the Rust "Serde" library we'll want Deserialize
and Serialize
so we can make JSON elements out of it. Then we'll also want FromForm
to use it as post request data.
#[derive(FromForm, Deserialize, Serialize)]
struct User {
...
}
Now we can make our endpoint, but we'll have to specify the "format" as JSON and the "data" as using our "user" type.
#[post("/users/create", format="json", data="<user>")]
fn create_user(user: Json<User>) -> String {
...
}
We need to provide the Json
wrapper for our input type, but we can use it as though it's a normal User
. For now, we'll just return a string echoing the user's information back to us. Don't forget to add each new endpoint to the routes
macro in your server definition!
#[post("/users/create", format="json", data="<user>")]
fn create_user(user: Json<User>) -> String {
format!(
"Created user: {} {} {}", user.name, user.email, user.age)
}
Conclusion
Next time, we'll explore making a more systematic CRUD server. We'll add database integration and see some other tricks for serializing data and maintaining state. Then we'll explore more advanced topics like authentication, static files, and templating!
If you're going to be building a web application in Rust, you'd better have a solid foundation! Watch our Rust Video Tutorial to get an in-depth introduction!
Diesel: A Rust-y ORM
Last week on Monday Morning Haskell we took our first step into some real world tasks with Rust. We explored the simple Rust Postgres library to connect to a database and run some queries. This week we're going to use Diesel, a library with some cool ORM capabilities. It's a bit like the Haskell library Persistent, which you can explore more in our Real World Haskell Series.
For a more in-depth look at the code for this article, you should take a look at our Github Repository! You'll want to look at the files referenced below and also at the executable here.
Diesel CLI
Our first step is to add Diesel as a dependency in our program. We briefly discussed Cargo "features" in last week's article. Diesel has separate features for each backend you might use. So we'll specify "postgres". Once again, we'll also use a special feature for the chrono
library so we can use timestamps in our database.
[dependencies]
diesel={version="1.4.4", features=["postgres", "chrono"]}
But there's more! Diesel comes with a CLI that helps us manage our database migrations. It also will generate some of our schema code. Just as we can install binaries with Stack using stack install
, we can do the same with Cargo. We only enable the features we need. Otherwise the build will fail if we don't have the other databases installed.
>> cargo install diesel_cli --no-default-features --features postgres
Now we can start using the program to set up our project and generate our migrations. We begin with this command.
>> diesel setup
This creates a couple different items in our project directory. First, we have a "migrations" folder, where we'll put some SQL code. Then we also get a schema.rs
file in our src
directory. Diesel will automatically generate code for us in this file. Let's see how!
Migrations and Schemas
When using Persistent in Haskell, we defined our basic types in a single Schema file using a special template language. We could run migrations on our whole schema programmatically, without our own SQL. But it is difficult to track more complex migrations as your schema evolves.
Diesel is a bit different. Unfortunately, we have to write our own SQL. But, we'll do so in a way that it's easy to take more granular actions on our table. Diesel will then generate a schema file for us. But we'll still need some extra work to get the Rust types we'll need. To start though, let's use Diesel to generate our first migration. This migration will create our "users" table.
>> diesel migration generate create_users
This creates a new folder within our "migrations" directory for this "create_users" migration. It has two files, up.sql
and down.sql
. We start by populating the up.sql
file to specify the SQL we need to run the migration.
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
email TEXT NOT NULL,
age INTEGER NOT NULL
)
Then we also want the down.sql
file to contain SQL that reverses the migration.
DROP TABLE users CASCADE;
Once we've written these, we can run our migration!
>> diesel migration run
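We can then test our down.sql
file with the command below. It reverts the latest migration and then immediately runs it again, so the table still exists afterward. (To undo a migration without re-running it, we would use diesel migration revert
instead.)
>> diesel migration redo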
The result of running our migration is that Diesel populates the schema.rs
file. This file uses the table
macro that generates helpful types and trait instances for us. We'll use this a bit when incorporating the table into our code.
table! {
users (id) {
id -> Int4,
name -> Text,
email -> Text,
age -> Int4,
}
}
While we're at it, let's make one more migration to add an articles table.
-- migrations/create_articles/up.sql
CREATE TABLE articles (
id SERIAL PRIMARY KEY,
title TEXT NOT NULL,
body TEXT NOT NULL,
published_at TIMESTAMP WITH TIME ZONE NOT NULL,
author_id INTEGER REFERENCES users(id) NOT NULL
)
-- migrations/create_articles/down.sql
DROP TABLE articles;
Then we can once again use diesel migration run
.
Model Types
Now, while Diesel will generate a lot of useful code for us, we still need to do some work on our own. We have to create our own structs for the data types to take advantage of the instances we get. With Persistent, we got these for free. Persistent also used a wrapper Entity
type, which attached a Key
to our actual data.
Diesel doesn't have the notion of an entity. We have to manually make two different types, one with the database key and one without. For the "Entity" type which has the key, we'll derive the "Queryable" trait. Then we can use Diesel's functions to select items from the table.
#[derive(Queryable)]
pub struct UserEntity {
pub id: i32,
pub name: String,
pub email: String,
pub age: i32
}
We then have to declare a separate type that implements "Insertable". This doesn't have the database key, since we don't know the key before inserting the item. This should be a copy of our entity type, but without the key field. We use a second attribute to tie it to the users
table.
#[derive(Insertable)]
#[table_name="users"]
pub struct User {
pub name: String,
pub email: String,
pub age: i32
}
Note that in the case of our foreign key type, we'll use a normal integer for our column reference. In Persistent we would have a special Key
type. We lose some of the semantic meaning of this field by doing this. But it can help keep more of our code separate from this specific library.
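If we wanted some of that semantic meaning back, one common Rust technique is the "newtype" pattern. Here's a minimal sketch that ignores Diesel entirely (using such wrappers as column types would need extra trait implementations that we don't show). The UserId and ArticleId names are hypothetical.

```rust
// Hypothetical newtype wrappers: both hold an i32, but the compiler treats
// them as distinct types, so they can't be mixed up by accident.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct UserId(i32);

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ArticleId(i32);

struct Article {
    id: ArticleId,
    title: String,
    author_id: UserId,
}

// Only accepts a UserId. Passing an ArticleId here is a compile error.
fn is_written_by(article: &Article, user: UserId) -> bool {
    article.author_id == user
}
```

The trade-off is the same one mentioned above: more type safety, but more glue code tying us to whichever library we use.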
Making Queries
Now that we have our models in place, we can start using them to write queries. First, we need to make a database connection using the establish
function. Rather than using the ?
syntax, we'll use .expect
to unwrap our results in this article. This is less safe, but a little easier to work with.
fn create_connection() -> PgConnection {
let database_url = "postgres://postgres:postgres@localhost/rust_db";
PgConnection::establish(&database_url)
.expect("Error Connecting to database")
}
fn main() {
let connection: PgConnection = create_connection();
...
}
Let's start now with insertion. Of course, we begin by creating one of our "Insertable" User
items. We can then start writing an insertion query with the Diesel function insert_into
.
Diesel's query functions are composable. We add different elements to the query until it is complete. With an insertion, we use values
combined with the item we want to insert. Then, we call get_result
with our connection. The result of an insertion is our "Entity" type.
fn create_user(conn: &PgConnection) -> UserEntity {
let u = User
{ name: "James".to_string()
, email: "james@test.com".to_string()
, age: 26};
diesel::insert_into(users::table).values(&u)
.get_result(conn).expect("Error creating user!")
}
Selecting Items
Selecting items is a bit more complicated. Diesel generates a dsl
module for each of our types. This allows us to use each field name as a value within "filters" and orderings. Let's suppose we want to fetch all the articles written by a particular user. We'll start our query on the articles
table and call filter
to start building our query. We can then add a constraint on the author_id
field.
fn fetch_articles(conn: &PgConnection, uid: i32) -> Vec<ArticleEntity> {
use rust_web::schema::articles::dsl::*;
articles.filter(author_id.eq(uid))
...
}
We can also add an ordering to our query. Notice again how these functions compose. We also have to specify the return type we want when using the load
function to complete our select query. The main case is to return the full entity. This is like SELECT * FROM
in SQL lingo. Applying load
will give us a vector of these items.
fn fetch_articles(conn: &PgConnection, uid: i32) -> Vec<ArticleEntity> {
use rust_web::schema::articles::dsl::*;
articles.filter(author_id.eq(uid))
.order(title)
.load::<ArticleEntity>(conn)
.expect("Error loading articles!")
}
But we can also specify particular fields that we want to return. We'll see this in the final example, where our result type is a vector of tuples. This last query will be a join between our two tables. We start with users
and apply the inner_join
function.
fn fetch_all_names_and_titles(conn: &PgConnection) -> Vec<(String, String)> {
use rust_web::schema::users::dsl::*;
use rust_web::schema::articles::dsl::*;
users.inner_join(...
}
Then we join it to the articles table on the particular ID field. Because both of our tables have id
fields, we have to namespace it to specify the user's ID field.
fn fetch_all_names_and_titles(conn: &PgConnection) -> Vec<(String, String)> {
use rust_web::schema::users::dsl::*;
use rust_web::schema::articles::dsl::*;
users.inner_join(
articles.on(author_id.eq(rust_web::schema::users::dsl::id)))...
}
Finally, we load
our query to get the results. But notice, we use select
and only ask for the name
of the User and the title
of the article. This gives us our final values, so that each element is a tuple of two strings.
fn fetch_all_names_and_titles(conn: &PgConnection) -> Vec<(String, String)> {
use rust_web::schema::users::dsl::*;
use rust_web::schema::articles::dsl::*;
users.inner_join(
articles.on(author_id.eq(rust_web::schema::users::dsl::id)))
.select((name, title)).load(conn).expect("Error on join query!")
}
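If the join is hard to picture, here's an in-memory analogy using plain vectors instead of tables. This is not how Diesel executes the query (the database does the work). The trimmed-down structs and the function name are ours, for illustration only.

```rust
// Simplified stand-ins for the two tables' row types.
struct UserEntity {
    id: i32,
    name: String,
}

struct ArticleEntity {
    title: String,
    author_id: i32,
}

// In-memory version of: inner join on author_id = users.id, select (name, title).
// Users with no articles produce no rows, as with a SQL inner join.
fn names_and_titles(users: &[UserEntity], articles: &[ArticleEntity]) -> Vec<(String, String)> {
    let mut pairs = Vec::new();
    for user in users {
        for article in articles.iter().filter(|a| a.author_id == user.id) {
            pairs.push((user.name.clone(), article.title.clone()));
        }
    }
    pairs
}
```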
Conclusion
For my part, I prefer the functionality provided by Persistent in Haskell. But Diesel's method of providing a separate CLI to handle migrations is very cool as well. And it's good to see more sophisticated functionality in this relatively new language.
If you're still new to Rust, we have some more beginner-related material. Read our Rust Beginners Series or better yet, watch our Rust Video Tutorial!
Basic Postgres Data in Rust
For our next few articles, we're going to be exploring some more advanced concepts in Rust. Specifically, we'll be looking at parallel ideas from our Real World Haskell Series. In these first couple weeks, we'll be exploring how to connect Rust and a Postgres database. To start, we'll use the Rust Postgres library. This will help us create a basic database connection so we can make simple queries. You can see all the code for this article in action by looking at our RustWeb repository. Specifically, you'll want to check out the file pg_basic.rs
.
If you're new to Rust, we have a couple beginner resources for you to start out with. You can read our Rust Beginners Series to get a basic introduction to the language. Or for some more in-depth explanations, you can watch our Rust Video Tutorial!
Creating Tables
We'll start off by making a client object to connect to our database. This uses a connection string like we would with any Postgres library.
let conn_string = "host=localhost port=5432 user=postgres";
let mut client: Client = Client::connect(conn_string, NoTls)?;
Note that the connect
function generally returns a Result<Client, Error>
. In Haskell, we would write this as Either Error Client
. By using ?
at the end of our call, we can immediately unwrap the Client
. The caveat on this is that it only compiles if the whole function returns some kind of Result<..., Error>
. This is an interesting monadic behavior Rust gives us. Pretty much all our functions in this article will use this ?
behavior.
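Here's a small standard-library sketch of what ? does for us, using string parsing instead of a database so it runs anywhere. The two functions below behave the same way. The second spells out roughly what the first abbreviates (the real operator also converts the error type where needed, which we leave out).

```rust
use std::num::ParseIntError;

// With `?`: an Err makes the whole function return early with that error.
fn parse_age_short(input: &str) -> Result<i32, ParseIntError> {
    let age = input.trim().parse::<i32>()?;
    Ok(age)
}

// Roughly the same logic written out by hand with `match`.
fn parse_age_long(input: &str) -> Result<i32, ParseIntError> {
    let age = match input.trim().parse::<i32>() {
        Ok(value) => value,
        Err(e) => return Err(e),
    };
    Ok(age)
}
```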
Now that we have a client, we can use it to run queries. The catch is that we have to know the Raw SQL ourselves. For example, here's how we can create a table to store some users:
client.batch_execute("\
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
email TEXT NOT NULL,
age INTEGER NOT NULL
)
")?;
Inserting with Interpolation
A raw query like that with no real result is the simplest operation we can perform. But, any non-trivial program will require us to customize the queries programmatically. To do this we'll need to interpolate values into the middle of our queries. We can do this with execute
(as opposed to batch_execute
).
Let's try creating a user. As with batch_execute
, we need a query string. This time, the query string will contain values like $1
, $2
that we'll fill in with variables. We'll provide these variables with a list of references. Here's what it looks like with a user:
let name = "James";
let email = "james@test.com";
let age = 26;
client.execute(
"INSERT INTO users (name, email, age) VALUES ($1, $2, $3)",
&[&name, &email, &age],
)?;
Again, we're using a raw query string. All the values we interpolate must implement the specific trait postgres_types::ToSql
. We'll see this a bit later.
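The "list of references" works because each element is a reference to some value implementing a shared trait. Here's a library-free sketch of that pattern. SqlParam is a made-up stand-in for the real trait, and it only reports a type name rather than doing any actual serialization.

```rust
// Hypothetical stand-in for a "can be used as a query parameter" trait.
trait SqlParam {
    fn sql_type(&self) -> &'static str;
}

impl SqlParam for i32 {
    fn sql_type(&self) -> &'static str {
        "INT4"
    }
}

impl SqlParam for &str {
    fn sql_type(&self) -> &'static str {
        "TEXT"
    }
}

// One slice can hold references to values of different concrete types,
// as long as each of them implements the trait.
fn param_types(params: &[&dyn SqlParam]) -> Vec<&'static str> {
    params.iter().map(|p| p.sql_type()).collect()
}
```

A value whose type lacks the trait implementation can't go in the slice at all, so that particular mistake gets caught at compile time.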
Fetching Results
The last main type of query we can perform is to fetch our results. We can use our client to call the query
function, which returns a vector of Row
objects:
for row in client.query("SELECT * FROM users", &[])? {
...
}
For more complicated SELECT
statements we would interpolate parameters, as with insertion above. The Row
has different Columns
for accessing the data. But in our case it's a little easier to use get
and the index to access the different fields. Like our Raw SQL calls, this is risky in a couple ways. If we use an out of bounds index, the program will panic. And if we try to read a column as the wrong data type, we'll also run into problems.
for row in client.query("SELECT * FROM users", &[])? {
let id: i32 = row.get(0);
let name: &str = row.get(1);
let email: &str = row.get(2);
let age: i32 = row.get(3);
...
}
We could then use these individual values to populate whatever data types we wanted on our end.
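Here's a toy model of that last step that also turns the two failure modes (bad index, wrong type) into ordinary errors. ToyRow, Value and RowError are invented for this sketch and are not the library's types. The real Row type does, to our knowledge, offer a fallible try_get alongside get, which serves a similar purpose.

```rust
// Toy column value: just the two kinds our users table needs.
#[derive(Debug, Clone, PartialEq)]
enum Value {
    Int(i32),
    Text(String),
}

#[derive(Debug, PartialEq)]
enum RowError {
    OutOfBounds(usize),
    WrongType(usize),
}

// Hypothetical stand-in for a database row.
struct ToyRow {
    columns: Vec<Value>,
}

impl ToyRow {
    fn get_int(&self, idx: usize) -> Result<i32, RowError> {
        match self.columns.get(idx) {
            Some(Value::Int(n)) => Ok(*n),
            Some(_) => Err(RowError::WrongType(idx)),
            None => Err(RowError::OutOfBounds(idx)),
        }
    }

    fn get_text(&self, idx: usize) -> Result<&str, RowError> {
        match self.columns.get(idx) {
            Some(Value::Text(s)) => Ok(s.as_str()),
            Some(_) => Err(RowError::WrongType(idx)),
            None => Err(RowError::OutOfBounds(idx)),
        }
    }
}

#[derive(Debug, PartialEq)]
struct User {
    id: i32,
    name: String,
    email: String,
    age: i32,
}

// Builds our own data type from a row, stopping at the first bad column.
fn user_from_row(row: &ToyRow) -> Result<User, RowError> {
    Ok(User {
        id: row.get_int(0)?,
        name: row.get_text(1)?.to_string(),
        email: row.get_text(2)?.to_string(),
        age: row.get_int(3)?,
    })
}
```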
Joining Tables
If we want to link two tables together, of course we'll also have to know how to do this with Raw SQL. For example, we can make our articles table:
client.batch_execute("\
CREATE TABLE articles (
id SERIAL PRIMARY KEY,
title TEXT NOT NULL,
body TEXT NOT NULL,
published_at TIMESTAMP WITH TIME ZONE NOT NULL,
author_id INTEGER REFERENCES users(id)
)
")?;
Then, after retrieving a user's ID, we can insert an article written by that user.
for row in client.query("SELECT * FROM users", &[])? {
let id: i32 = row.get(0);
let title: &str = "A Great Article!";
let body: &str = "You should share this with friends.";
let cur_time: DateTime<Utc> = Utc::now();
client.execute(
"INSERT INTO articles (title, body, published_at, author_id) VALUES ($1, $2, $3, $4)",
&[&title, &body, &cur_time, &id]
)?;
}
One of the tricky parts is that this won't compile if you only use the basic postgres
dependency in Rust! There isn't a native ToSql
instance for the DateTime<Utc>
type. However, Rust dependencies can have specific "features". This concept doesn't really exist in Haskell, except through extra packages. You'll need to specify the with-chrono
feature for the version of the chrono
library you use. This feature, or sub-dependency, contains the necessary ToSql
instance. Here's what the structure looks like in our Cargo.toml
:
[dependencies]
chrono="0.4"
postgres={version="0.17.3", features=["with-chrono-0_4"]}
After this, our code will compile!
Runtime Problems
Now there are lots of reasons we wouldn't want to use a library like this in a formal project. One of the big principles of Rust (and Haskell) is catching errors at compile time. And writing out functions with lots of raw SQL like this makes our program very prone to runtime errors. I encountered several of these as I was writing this small program! At one point, I started writing the SELECT
query and absentmindedly forgot to complete it until I ran my program!
At another point, I couldn't decide what timestamp format to use in Postgres. I went back and forth between using a TIMESTAMP
or just an INTEGER
for the published_at
field. I needed to coordinate the SQL for both the table creation query and the fetching query. I often managed to change one but not the other, resulting in annoying runtime errors. I finally discovered I needed TIMESTAMP WITH TIME ZONE
and not merely TIMESTAMP
. This was a rather painful process with this setup.
Conclusion
Next week, we'll explore Diesel, a library that lets us use schemas to catch more of these issues at compile time. The framework is more comparable to Persistent in Haskell. It gives us an ORM (Object Relational Mapping) so that we don't have to write raw SQL. This approach is much more suited to languages like Haskell and Rust!
To try out tasks like this in Haskell, take a look at our Production Checklist! It includes a couple different libraries for interacting with databases using ORMs.
Preparing for Rust!
Next week, we're going to change gears a bit and start some interesting projects with Rust! Towards the end of last year, we dabbled a bit with Rust and explored some of the basics of the language. In our next series of blog articles, we're going to take a deep dive into some more advanced concepts.
We'll explore several different Rust libraries covering various topics. We'll consider data serialization, web servers and databases, among others. We'll build a couple small apps, and compare the results to our earlier work with Haskell.
To get ready for this series, you should brush up on your Rust basics! To help, we've wrapped up our Rust content into a permanent series on the Beginners page! Here's an overview of that series:
Part 1: Basic Syntax
We start out by learning about Rust's syntax. We'll see quite a few differences to Haskell. But there are also some similarities in unexpected places.
Part 2: Memory Management
One of the major things that sets Rust apart from other languages is how it manages memory. In the second part, we'll learn a bit about how Rust's memory system works.
Part 3: Data Types
In the third part of the series, we'll explore how to make our own data types in Rust. We'll see that Rust borrows some of Haskell's neat ideas!
Part 4: Cargo Package Manager
Cargo is Rust's equivalent of Stack and Cabal. It will be our package and dependency manager. In part 4, we see how to make basic Rust projects using Cargo.
Part 5: Lifetimes and Collections
In the final part, we'll look at some more advanced collection types in Rust. Because of Rust's memory model, we'll need some special rules for handling items in collections. This will lead us to the idea of lifetimes.
If you prefer video content, our Rust Video Tutorial also provides a solid foundation. It goes through all the topics in this series, starting from installation. Either way, stay tuned for new blog content, starting next week!
Summer Course Sale!
This week we have some exciting news! Back in March, we opened our Practical Haskell course for enrollment. The first round of students has had a chance to go through the course. So we're now opening it up for general enrollment!
This course goes through some more practical concepts and libraries you might use on a real world project. Here's a sneak peek at some of the skills you'll learn:
- Making a web server with Persistent and Servant
- Deploying a Haskell project using Heroku and Circle CI
- Making a web frontend with Elm, and connecting it to the Haskell backend
- Using Monad Transformers and Free Effects to organize our application
- Test driven development in Haskell
As a special bonus, for this week only, both of our courses are on sale, $100 off their normal prices! So if you're not ready for Practical Haskell, you can take a look at Haskell From Scratch. With that said, if you buy either course now, you'll have access to all the materials indefinitely! Prices will go back to normal after this Sunday, so head to the course pages now!
Next week, we'll start getting back into the swing of things by reviewing some of our Rust basics!
Mid-Summer Break, Open AI Gym Series!
We're taking a little bit of a mid-summer break from new content here at MMH. But we have done some extra work in organizing the site! Last week we wrapped up our series on Haskell and the Open AI Gym. We've now added that series as a permanent fixture on the advanced section of the page!
Here's a quick summary of the series:
Part 1: Frozen Lake Primer
The first part introduces the Open AI framework and goes through the Frozen Lake example. It presents the core concept of an environment.
Part 2: Frozen Lake in Haskell
In the second part, we write a basic version of Frozen Lake in Haskell.
Part 3: Blackjack
Next, we expand on our knowledge of games and environments to write a second game. This one is based on casino Blackjack, and it will start to show us common elements in games.
Part 4: Q-Learning
Now we start getting into the ideas of reinforcement learning. We'll explore Q-Learning, one of the simplest techniques in this field. We'll apply this approach to both of our games.
Part 5: Generalized Environments
Now that we've seen the learning process in action, we can start generalizing our games. We'll create an abstract notion of what an Environment
is. Just as Python has a specific API for their games, so will we! In true Haskell fashion, we'll represent this API with a type family!
Part 6: Q-Learning with Tensors in Python
In part 6, we'll take our Q-learning process a step further by using TensorFlow. We'll see how we can learn a more general function than we had before. We'll start this process in Python, where the mathematical operations are more clear.
Part 7: Q-Learning with Tensors in Haskell
Once we know how Q-Learning works with Python, we'll apply these techniques in Haskell as well! Once you get here, you'd better be ready to use your Haskell TensorFlow skills!
Part 8: Rendering with Gloss
In the final part of the series, we'll see how we can use the Gloss library to render our Haskell games!
You can take a look at the series summary page for more details!
In a couple weeks, we'll be back, this time with some fresh Rust content! Take a look at our Rust Video Tutorial to get a headstart on that!
Rendering Frozen Lake with Gloss!
We've spent the last few weeks exploring some of the ideas in the Open AI Gym framework. We made a couple games, generalized them, and applied some machine learning techniques. When it comes to rendering our games though, we're still relying on a very basic command line text format.
But if we want to design agents for more visually appealing games, we'll need a better solution! Last year, we spent quite a lot of time learning about the Gloss library. This library makes it easy to create simple games and render them using OpenGL. Take a look at this article for a summary of our work there and some links to the basics.
In this article, we'll explore how we can draw some connections between Gloss and our Open AI Gym work. We'll see how we can take the functions we've already written and use them within Gloss!
Gloss Basics
The key entrypoint for a Gloss game is the play
function. At its core is the world
type parameter, which we'll define for ourselves later.
play :: Display -> Color -> Int
-> world
-> (world -> Picture)
-> (Event -> world -> world)
-> (Float -> world -> world)
-> IO ()
We won't go into the first three parameters. But the rest are important. The first of these is our initial world state. The second is our rendering function. It creates a Picture
for the current state. Then comes an "event handler". This takes user input events and updates the world based on the actions. Finally there is the update function. This changes the world based on the passage of time, rather than specific user inputs.
This structure should sound familiar, because it's a lot like our Open AI environments! The initial world is like the "reset" function. Then both systems have a "render" function. And the update functions are like our stepEnv
function.
The main difference we'll see is that Gloss's functions work in a pure way. Recall our "environment" functions use the "State" monad. Let's explore this some more.
Re-Writing Environment Functions
Let's take a look at the basic form of these environment functions, in the Frozen Lake context:
resetEnv :: (Monad m) => StateT FrozenLakeEnvironment m Observation
stepEnv :: (Monad m) =>
Action -> StateT FrozenLakeEnvironment m (Observation, Double, Bool)
renderEnv :: (MonadIO m) => StateT FrozenLakeEnvironment m ()
These all use State
. This makes it easy to chain them together. But if we look at the implementations, a lot of them don't really need to use State
. They tend to unwrap the environment at the start with get
, calculate new results, and then have a final put
call.
This means we can rewrite them to fit more within Gloss's pure structure! We'll ignore rendering, since that will be very different. But here are some alternate type signatures:
resetEnv' :: FrozenLakeEnvironment -> FrozenLakeEnvironment
stepEnv' :: Action -> FrozenLakeEnvironment
-> (FrozenLakeEnvironment, Double, Bool)
We'll exclude Observation
as an output, since the environment contains that through currentObservation
. The implementation for each of these looks like the original. Here's what resetting looks like:
resetEnv' :: FrozenLakeEnvironment -> FrozenLakeEnvironment
resetEnv' fle = fle
{ currentObservation = 0
, previousAction = Nothing
}
Now for stepping our environment forward:
stepEnv' :: Action -> FrozenLakeEnvironment -> (FrozenLakeEnvironment, Double, Bool)
stepEnv' act fle = (finalEnv, reward, done)
  where
    currentObs = currentObservation fle
    (slipRoll, gen') = randomR (0.0, 1.0) (randomGenerator fle)
    allLegalMoves = legalMoves currentObs (dimens fle)
    numMoves = length allLegalMoves - 1
    (randomMoveIndex, finalGen) = randomR (0, numMoves) gen'
    newObservation = ... -- Random move, or apply the action
    (done, reward) = case (grid fle) A.! newObservation of
      Goal -> (True, 1.0)
      Hole -> (True, 0.0)
      _ -> (False, 0.0)
    finalEnv = fle
      { currentObservation = newObservation
      , randomGenerator = finalGen
      , previousAction = Just act
      }
What's even better is that we can now rewrite our original State
functions using these!
resetEnv :: (Monad m) => StateT FrozenLakeEnvironment m Observation
resetEnv = do
  modify resetEnv'
  gets currentObservation

stepEnv :: (Monad m) =>
  Action -> StateT FrozenLakeEnvironment m (Observation, Double, Bool)
stepEnv act = do
  fle <- get
  let (finalEnv, reward, done) = stepEnv' act fle
  put finalEnv
  return (currentObservation finalEnv, reward, done)
Implementing Gloss
Now let's see how this ties in with Gloss. It might be tempting to use our Environment
as the world
type. But it can be useful to attach other information as well. For one example, we can also include the current GameResult
, telling us if we've won, lost, or if the game is still going.
data GameResult =
GameInProgress |
GameWon |
GameLost
deriving (Show, Eq)
data World = World
{ environment :: FrozenLakeEnvironment
, gameResult :: GameResult
}
Now we can start building the other pieces of our game. There aren't really any "time" updates in our game, except to update the result based on our location:
updateWorldTime :: Float -> World -> World
updateWorldTime _ w = case tile of
    Goal -> World fle GameWon
    Hole -> World fle GameLost
    _ -> w
  where
    fle = environment w
    obs = currentObservation fle
    tile = grid fle A.! obs
When it comes to handling inputs, we need to start with the case of restarting the game. When the game isn't InProgress
, only the "enter" button matters. This resets everything, using resetEnv'
:
handleInputs :: Event -> World -> World
handleInputs event w
| gameResult w /= GameInProgress = case event of
(EventKey (SpecialKey KeyEnter) Down _ _) ->
World (resetEnv' fle) GameInProgress
_ -> w
...
Now we handle each directional input key. We'll make a helper function at the bottom that does the business of calling stepEnv'
.
handleInputs :: Event -> World -> World
handleInputs event w
  | gameResult w /= GameInProgress = case event of
      (EventKey (SpecialKey KeyEnter) Down _ _) ->
        World (resetEnv' fle) GameInProgress
      -- Ignore every other input until the game restarts
      _ -> w
  | otherwise = case event of
      (EventKey (SpecialKey KeyUp) Down _ _) ->
        w {environment = finalEnv MoveUp }
      (EventKey (SpecialKey KeyRight) Down _ _) ->
        w {environment = finalEnv MoveRight }
      (EventKey (SpecialKey KeyDown) Down _ _) ->
        w {environment = finalEnv MoveDown }
      (EventKey (SpecialKey KeyLeft) Down _ _) ->
        w {environment = finalEnv MoveLeft }
      _ -> w
  where
    fle = environment w
    finalEnv action =
      let (fe, _, _) = stepEnv' action fle
      in fe
The last step is rendering the environment with a draw
function. This just requires a working knowledge of constructing the Picture
type in Gloss. It's a little tedious, so I've included the full implementation as an appendix at the bottom. We can then combine all these pieces like so:
main :: IO ()
main = do
  env <- basicEnv
  play windowDisplay white 20
    (World env GameInProgress)
    drawEnvironment
    handleInputs
    updateWorldTime
After we have all these pieces, we can run our game, moving our player around to reach the green tile while avoiding the black tiles!
Conclusion
With a little more plumbing, it would be possible to combine this with the rest of our "Environment" work. There are some definite challenges. Our current environment setup doesn't have a "time update" function. Combining machine learning with Gloss rendering would also be interesting. This is the end of our Open Gym series for now, but I'll definitely be working on this project more in the future! Next week we'll have a summary and review what we've learned!
Take a look at our Github repository to see all the code we wrote in this series! The code for this article is on the gloss
branch. And don't forget to Subscribe to Monday Morning Haskell to get our monthly newsletter!
Appendix: Rendering Frozen Lake
A lot of numbers here are hard-coded for a 4x4 grid, where each cell is 100x100. Notice particularly that we have a text message if we've won or lost.
windowDisplay :: Display
windowDisplay = InWindow "Window" (400, 400) (10, 10)
drawEnvironment :: World -> Picture
drawEnvironment world
| gameResult world == GameWon = Translate (-150) 0 $ Scale 0.12 0.25
(Text "You've won! Press enter to restart!")
| gameResult world == GameLost = Translate (-150) 0 $ Scale 0.12 0.25
(Text "You've lost :( Press enter to restart.")
| otherwise = Pictures [tiles, playerMarker]
where
observationToCoords :: Word -> (Word, Word)
observationToCoords w = quotRem w 4
renderTile :: (Word, TileType) -> Picture
renderTile (obs, tileType ) =
let (centerX, centerY) = rowColToCoords . observationToCoords $ obs
color' = case tileType of
Goal -> green
Hole -> black
_ -> blue
in Translate centerX centerY (Color color' (Polygon [(-50, -50), (-50, 50), (50, 50), (50, -50)]))
tiles = Pictures $ map renderTile (A.assocs (grid . environment $ world))
(px, py) = rowColToCoords . observationToCoords $ (currentObservation . environment $ world)
playerMarker = translate px py (Color red (ThickCircle 10 3))
rowColToCoords :: (Word, Word) -> (Float, Float)
rowColToCoords (row, col) = (100 * (fromIntegral col - 1.5), 100 * (1.5 - fromIntegral row))
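The coordinate arithmetic above is easy to sanity check outside of Gloss. Here is a minimal sketch in Python (the language used for the TensorFlow experiments in this series) that mirrors observationToCoords and rowColToCoords. The snake_case function names are just illustrative translations, and the 4x4 grid with 100x100 cells is the same hard-coded assumption as in the Haskell code.

```python
def observation_to_coords(obs):
    # (row, col) on a 4x4 grid, mirroring `quotRem obs 4`
    return divmod(obs, 4)

def row_col_to_coords(row, col):
    # Center of a 100x100 cell, with the window origin in the middle of the grid.
    # Columns grow to the right (x), rows grow downward (so y decreases).
    return (100 * (col - 1.5), 100 * (1.5 - row))

top_left = row_col_to_coords(*observation_to_coords(0))       # (-150.0, 150.0)
bottom_right = row_col_to_coords(*observation_to_coords(15))  # (150.0, -150.0)
```

So observation 0 lands in the top left cell and observation 15 in the bottom right, which is why the 400x400 window fits the whole board.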
Training our Agent with Haskell!
In the previous part of the series, we used the ideas of Q-Learning together with TensorFlow. We got a more general solution to our agent that didn't need a table for every state of the game.
This week, we'll take the final step and implement this TensorFlow approach in Haskell. We'll see how to integrate this library with our existing Environment
system. It works out quite smoothly, with a nice separation between our TensorFlow logic and our normal environment logic!
This article requires a working knowledge of the Haskell TensorFlow integration. If you're new to this, you should download our Guide showing how to work with this framework. You can also read our original Machine Learning Series for some more details! In particular, the second part will go through the basics of tensors.
Building Our TF Model
The first thing we want to do is construct a "model". This model type will store three items. The first will be the tensor for the weights we have. The other two will be functions in the TensorFlow Session
monad. The first function will provide scores for the different moves in a position, so we can choose our move. The second will allow us to train the model and update the weights.
data Model = Model
{ weightsT :: Variable Float
, chooseActionStep :: TensorData Float -> Session (Vector Float)
, learnStep :: TensorData Float -> TensorData Float -> Session ()
}
The input for choosing an action is our world observation state, converted to a Float
and put in a vector of size 16. The result will be 4 floating point values for the scores. Then our learning step will take in the observation as well as a set of 4 values. These are the "target" values we're training our model on.
We can construct our model within the Session
monad. In the first part of this process we define our weights and use them to determine the score of each move (results).
createModel :: Session Model
createModel = do
-- Choose Action
inputs <- placeholder (Shape [1, 16])
weights <- truncatedNormal (vector [16, 4]) >>= initializedVariable
let results = inputs `matMul` readValue weights
returnedOutputs <- render results
...
Now we make our "trainer". Our "loss" function is the reduced, squared difference between our results and the "target" outputs. We'll use the adam
optimizer to learn values for our weights to minimize this loss.
createModel :: Session Model
createModel = do
  -- Choose Action
  ...
  -- Train Network
  (nextOutputs :: Tensor Value Float) <- placeholder (Shape [4, 1])
  let (diff :: Tensor Build Float) = nextOutputs `sub` results
  let (loss :: Tensor Build Float) = reduceSum (diff `mul` diff)
  trainer_ <- minimizeWith adam loss [weights]
  ...
Finally, we wrap these tensors into functions we can call using runWithFeeds
. Recall that each feed
provides us with a way to fill in one of our placeholder tensors.
createModel :: Session Model
createModel = do
-- Choose Action
...
-- Train Network
...
-- Create Model
let chooseStep = \inputFeed ->
runWithFeeds [feed inputs inputFeed] returnedOutputs
let trainStep = \inputFeed nextOutputFeed ->
runWithFeeds [ feed inputs inputFeed
, feed nextOutputs nextOutputFeed
]
trainer_
return $ Model weights chooseStep trainStep
Our model now wraps all the different tensor operations we need! All we have to do is provide it with the correct TensorData
. To see how that works, let's start integrating with our EnvironmentMonad
!
Integrating With Environment
Our model's functions exist within the TensorFlow monad Session
. So how then, do we integrate this with our existing Environment
code? The answer is, of course, to construct a new monad! This monad will wrap Session
, while still giving us our FrozenLakeEnvironment
! We'll keep the environment within a State
, but we'll also keep a reference to our Model
.
newtype FrozenLake a = FrozenLake
(StateT (FrozenLakeEnvironment, Model) Session a)
deriving (Functor, Applicative, Monad)
instance (MonadState FrozenLakeEnvironment) FrozenLake where
get = FrozenLake (fst <$> get)
put fle = FrozenLake $ do
(_, model) <- get
put (fle, model)
Now we can start implementing the actual EnvironmentMonad
instance. Most of our existing types and functions will work with trivial modification. The only real change is that runEnv
will need to run a TensorFlow session and create the model. Then it can use evalStateT
.
instance EnvironmentMonad FrozenLake where
type (Observation FrozenLake) = FrozenLakeObservation
type (Action FrozenLake) = FrozenLakeAction
type (EnvironmentState FrozenLake) = FrozenLakeEnvironment
baseEnv = basicEnv
currentObservation = currentObs <$> get
resetEnv = resetFrozenLake
stepEnv = stepFrozenLake
runEnv env (FrozenLake action) = runSession $ do
model <- createModel
evalStateT action (env, model)
This is all we need to define the first class. But, with TensorFlow, our environment is only useful if we use the tensor model! This means we need to fill in LearningEnvironment
as well. This class has two functions, chooseActionBrain
and learnEnv
, which we'll implement using our tensors. Let's see how that works.
Choosing an Action
Choosing an action is straightforward. We'll once again start with the same format for sometimes choosing a random move:
chooseActionTensor :: FrozenLake FrozenLakeAction
chooseActionTensor = FrozenLake $ do
(fle, model) <- get
let (exploreRoll, gen') = randomR (0.0, 1.0) (randomGenerator fle)
if exploreRoll < flExplorationRate fle
then do
let (actionRoll, gen'') = Rand.randomR (0, 3) gen'
put $ (fle { randomGenerator = gen'' }, model)
return (toEnum actionRoll)
else do
...
As in Python, we'll need to convert an observation to a tensor type. This time, we'll create TensorData
. This type wraps a vector, and our input should have the size 1x16. It has the format of a oneHot
tensor. But it's easier to make this a pure function, rather than using a TensorFlow monad.
obsToTensor :: FrozenLakeObservation -> TensorData Float
obsToTensor obs = encodeTensorData (Shape [1, 16]) (V.fromList asList)
where
asList = replicate (fromIntegral obs) 0.0 ++
[1.0] ++
replicate (fromIntegral (15 - obs)) 0.0
Since we've already defined chooseActionStep
within the model, it's easy to use this! We convert the current observation, get the result values, and then pick the best index!
chooseActionTensor :: FrozenLake FrozenLakeAction
chooseActionTensor = FrozenLake $ do
(fle, model) <- get
-- Random move
...
else do
let obs1 = currentObs fle
let obs1Data = obsToTensor obs1
-- Use model!
results <- lift ((chooseActionStep model) obs1Data)
let bestMoveIndex = V.maxIndex results
put $ (fle { randomGenerator = gen' }, model)
return (toEnum bestMoveIndex)
Learning From the Environment
One unfortunate part of our current design is that we have to repeat some work in our learning function. To learn from our action, we need to use all the values, not just the chosen action. So to start our learning function, we'll call chooseActionStep
again. This time we'll get the best index AND the max score.
learnTensor ::
FrozenLakeObservation -> FrozenLakeObservation ->
Reward -> FrozenLakeAction ->
FrozenLake ()
learnTensor obs1 obs2 (Reward reward) action = FrozenLake $ do
model <- snd <$> get
let obs1Data = obsToTensor obs1
-- Use the model!
results <- lift ((chooseActionStep model) obs1Data)
let (bestMoveIndex, maxScore) =
(V.maxIndex results, V.maximum results)
...
We can now get our "target" values by substituting in the reward and max score at the proper index. Then we convert the second observation to a tensor, and we have all our inputs to call our training step!
learnTensor ::
FrozenLakeObservation -> FrozenLakeObservation ->
Reward -> FrozenLakeAction ->
FrozenLake ()
learnTensor obs1 obs2 (Reward reward) action = FrozenLake $ do
...
let (bestMoveIndex, maxScore) =
(V.maxIndex results, V.maximum results)
let targetActionValues = results V.//
[(bestMoveIndex, double2Float reward + (gamma * maxScore))]
let obs2Data = obsToTensor obs2
let targetActionData = encodeTensorData
(Shape [4, 1])
targetActionValues
-- Use the model!
lift $ (learnStep model) obs2Data targetActionData
where
gamma = 0.81
Using these two functions, we can now fill in our LearningEnvironment
class!
instance LearningEnvironment FrozenLake where
chooseActionBrain = chooseActionTensor
learnEnv = learnTensor
-- Same as before
explorationRate = ..
reduceExploration = ...
We'll then be able to run this code just as we would our other Q-learning examples!
Conclusion
This wraps up the machine learning part of this series. We'll have one more article about Open Gym next week. We'll compare our current setup and the Gloss library. Gloss offers much more extensive possibilities for rendering our game and accepting input. So using it would expand the range of games we could play!
We'll definitely continue to expand on the Open Gym concept in the future! Expect a more formal approach to this at some point! For now, take a look at our Github repository for this series! This article's code is on the tensorflow
branch!
Q-Learning with Tensors
In our last article we finished refactoring our Gym code to use a type family. This would make it much easier to add new games to our framework in the future. We're now in the closing stages of this series on AI and agent development. This week we're going to incorporate TensorFlow and perform some more advanced techniques.
We've used Q-Learning to train some agents to play simple games like Frozen Lake and Blackjack. Our existing approach uses an exhaustive table from observations to expected rewards. But in most games we won't be able to construct such an exhaustive table. The observation space will be too large, or it will be continuous. So in this article, we're going to explore how to use TensorFlow to build a more generic function we can learn. We'll start this process in Python, where there's a bit less overhead.
Next up, we'll be using TensorFlow with our Haskell code. We'll explore an alternative form of our FrozenLake
monad using this approach. To make sure you're ready for it, download our Haskell TensorFlow Guide.
A Q-Function
Our goal here will be to make a more general Q-Function, instead of using a table. A Q-Function provides another way of writing our chooseAction
function. With the table approach, each of the 16 possible observations had 4 scores, one for each of the actions we can take. To choose an action, we just take the index with the highest score.
We now want to incorporate a simple neural network for chooseAction
. In our example, this network will consist of a single matrix of weights. The input to our network will be a vector of size 16. This vector will have all zeroes, except for the index of the current observation, which will be 1. Then the output of the network will be a vector of size 4. These will give the scores for each move from that observation. So our "weights" will have size 16x4.
A useful helper function we can already write converts an observation to an input tensor. It makes use of the identity matrix: row i of the 16x16 identity matrix is exactly the vector with a 1 at index i and zeroes everywhere else.
def obs_to_tensor(obs):
    return np.identity(16)[obs:obs+1]
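As a quick check of this helper, here's a minimal sketch using only NumPy. The observation value 5 is an arbitrary example. It shows that slicing the identity matrix (rather than indexing it) keeps the result two-dimensional, which is the 1x16 shape the network input expects.

```python
import numpy as np

def obs_to_tensor(obs):
    # Row `obs` of the 16x16 identity matrix, kept 2-D with shape (1, 16).
    return np.identity(16)[obs:obs+1]

example = obs_to_tensor(5)         # 5 is an arbitrary observation for illustration
hot_index = int(np.argmax(example))

print(example.shape)   # (1, 16)
print(hot_index)       # 5
```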
Building the Graph
We can now go ahead and start building our tensor graph. We'll start with the part that makes moves from an observation. For this quick Python script, we'll let the tensors live in the global namespace.
import gym
import numpy as np
import tensorflow as tf
tf.reset_default_graph()
env = gym.make('FrozenLake-v0')
inputs = tf.placeholder(shape=[1,16], dtype=tf.float32)
weights = tf.Variable(tf.random_uniform([16, 4], 0, 0.01))
output = tf.matmul(inputs, weights)
prediction = tf.argmax(output, 1)
Each time we make a move, we'll pass the current observation tensor as the input placeholder. Then we multiply it by the weights to get scores for each different output action. Our final "prediction" is the output index with the highest score. Notice how we initialize our network with random weights. This helps prevent our network from getting stuck early on.
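To see what this part of the graph computes, here's a rough NumPy-only sketch of the same forward pass, with no TensorFlow involved. The seed and the observation value 3 are arbitrary choices for illustration. Because the input is one-hot, the matrix multiplication simply selects one row of the weights.

```python
import numpy as np

rng = np.random.default_rng(0)                  # fixed seed, only for repeatability
weights = rng.uniform(0, 0.01, size=(16, 4))    # same shape and range as the TF variable

obs = 3                                         # arbitrary example observation
inputs = np.identity(16)[obs:obs+1]             # one-hot row, shape (1, 16)

output = inputs @ weights                       # shape (1, 4): one score per action
prediction = np.argmax(output, axis=1)          # index of the highest score

# Multiplying by a one-hot row just picks out the matching row of the weights.
selects_row = np.allclose(output[0], weights[obs])
```

In other words, with a single weight matrix this network is still a table in disguise. The payoff of the tensor formulation comes when we swap in richer inputs or more layers.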
We can use these tensors to construct our choose_action
function. This will, of course, take the current observation as an input. But it will also take an epsilon
value for the random move probability. We use sess.run
to run our prediction and output tensors. If we choose a random move instead, we'll replace the actual "action" with a sample from the action space.
def choose_action(input_obs, epsilon):
    action, all_outputs = sess.run(
        [prediction, output],
        feed_dict={inputs: obs_to_tensor(input_obs)})
    if np.random.rand(1) < epsilon:
        action[0] = env.action_space.sample()
    return action, all_outputs
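The random-move logic here is the usual "epsilon-greedy" rule. As a standalone sketch, assuming we already have the four scores in hand, it could look like the following. The epsilon_greedy name, the made-up scores and the seeded generator are all illustrative and not part of the script above.

```python
import numpy as np

def epsilon_greedy(scores, epsilon, rng):
    """Pick a random action with probability epsilon, else the best-scoring one."""
    if rng.random() < epsilon:
        return int(rng.integers(len(scores)))   # explore: uniform random action
    return int(np.argmax(scores))               # exploit: highest score

rng = np.random.default_rng(42)
scores = np.array([0.1, 0.7, 0.2, 0.0])   # made-up scores for the 4 actions

greedy_action = epsilon_greedy(scores, 0.0, rng)   # epsilon = 0: never explores
random_action = epsilon_greedy(scores, 1.0, rng)   # epsilon = 1: always explores
```

With epsilon at 0 we always get index 1, the highest score. This is the same trick the learning step uses later when it needs a move "without randomness".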
The Learning Process
The first part of our graph tells us how to make moves, but we also need to update our weights so the network gets better! To do this, we'll add a few more tensors.
next_output = tf.placeholder(shape=[1,4], dtype=tf.float32)
loss = tf.reduce_sum(tf.square(next_output - output))
trainer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
update_model = trainer.minimize(loss)
init = tf.global_variables_initializer()
Let's go through these one-by-one. We need to take an extra input for the target values, which incorporate the "next" state of the game. We want the values we get in the original state to be closer to those! So our "loss" function is the squared difference of our "current" output and the "target" output. Then we create a "trainer" that minimizes the loss function. Because our weights are the "variable" in the system, they'll get updated to minimize this loss.
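To make the loss and the update concrete, here's a hand-rolled sketch of one gradient descent step in NumPy. It doesn't use TensorFlow at all, and the zero weights, the target vector and the observation index are made-up values chosen so the arithmetic is easy to follow. For this loss, the gradient with respect to the weights is -2 * x^T (target - x W).

```python
import numpy as np

learning_rate = 0.1
obs = 2                                      # arbitrary example observation
inputs = np.identity(16)[obs:obs+1]          # one-hot input, shape (1, 16)
weights = np.zeros((16, 4))                  # zeros only to keep the numbers simple
target = np.array([[0.0, 1.0, 0.0, 0.0]])    # made-up target scores, shape (1, 4)

def loss_fn(w):
    # Sum of squared differences between the target and the current output.
    return float(np.sum(np.square(target - inputs @ w)))

# Gradient of the loss with respect to the weights: -2 * x^T (target - x W).
grad = -2.0 * inputs.T @ (target - inputs @ weights)
new_weights = weights - learning_rate * grad

loss_before = loss_fn(weights)       # 1.0
loss_after = loss_fn(new_weights)    # about 0.64
```

Only the row of weights for the current observation changes. With a learning rate of 0.1, each step closes 20% of the gap to the target, so the squared error shrinks by a factor of 0.64.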
We can use this second group of tensors to construct our "learning" function.
def learn_env(current_obs, next_obs, reward, action, all_outputs):
    gamma = 0.81
    _, all_next_outputs = choose_action(next_obs, 0.0)
    next_max = np.max(all_next_outputs)
    target_outputs = all_outputs
    target_outputs[0, action[0]] = reward + gamma * next_max
    sess.run(
        [update_model, weights],
        feed_dict={inputs: obs_to_tensor(current_obs),
                   next_output: target_outputs})
We start by choosing an action from the "next" position (without randomness). We get the largest value from that choice. We use this and the reward to inform our "target" of what the current input weights should be. In other words, taking our action should give us the reward and the best value we would get from the next position. Then we update our model!
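Here's a small worked example of building that target, using made-up numbers in place of real network outputs. One difference from the function above: this sketch copies the current outputs before editing them, purely so we can compare the before and after values.

```python
import numpy as np

gamma = 0.81
reward = 1.0                                           # made-up reward for this step
action = 2                                             # made-up action that was taken
all_outputs = np.array([[0.1, 0.2, 0.3, 0.4]])         # current scores (made up)
all_next_outputs = np.array([[0.0, 0.5, 0.25, 0.1]])   # next-state scores (made up)

next_max = np.max(all_next_outputs)                    # 0.5
target_outputs = all_outputs.copy()                    # copy so the original is untouched
# 1.0 + 0.81 * 0.5 = 1.405 replaces the score of the action we took
target_outputs[0, action] = reward + gamma * next_max
```

Only the entry for the chosen action changes. The other three targets equal the current outputs, so they contribute nothing to the loss.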
Playing the Game
Now all that's left is to play out the game! This looks a lot like code from previous parts, so we won't go into too much depth. The key section is in the middle of the loop. We choose our next action, use it to step the environment, and use the reward to learn.
rewards_list = []
with tf.Session() as sess:
    sess.run(init)
    epsilon = 0.9
    decay_rate = 0.9
    num_episodes = 10000
    for i in range(num_episodes):
        # Reset environment and get first new observation
        current_obs = env.reset()
        sum_rewards = 0
        done = False
        num_steps = 0
        while num_steps < 100:
            num_steps += 1
            # Choose, Step, Learn!
            action, all_outputs = choose_action(current_obs, epsilon)
            next_obs, reward, done, _ = env.step(action[0])
            learn_env(current_obs, next_obs, reward, action, all_outputs)
            sum_rewards += reward
            current_obs = next_obs
            if done:
                # Reduce exploration once every 100 episodes
                if i % 100 == 99:
                    epsilon *= decay_rate
                break
        rewards_list.append(sum_rewards)
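It's worth a quick look at how fast the exploration rate falls under this schedule. The sketch below replays only the decay logic, with no environment or network. It assumes every episode reaches the done state within its 100 steps, so the decay check runs once per episode. If an episode times out on a decay episode, that decay is skipped.

```python
epsilon = 0.9
decay_rate = 0.9
num_episodes = 10000

decays = 0
for i in range(num_episodes):
    # Assumption for this sketch: every episode reaches `done`,
    # so the decay check runs once per episode.
    if i % 100 == 99:
        epsilon *= decay_rate
        decays += 1

final_epsilon = epsilon   # equals 0.9 * 0.9 ** decays
```

Under that assumption we get 100 decays, which leaves the final exploration rate very close to zero. By the end of training the agent is almost purely following its learned scores.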
Our results won't be quite as good as the table approach. Using a tensor function allows our system to be a lot more general. But the consequence of this is that the results aren't stable. We could, of course, improve the results by using more advanced algorithms. But we'll get into that another time!
Conclusion
Now that we know the core ideas behind using tensors for Q-Learning, it's time to do this in Haskell. Next week, we'll do a refresher on how Haskell operates together with Tensor Flow. We'll see how we can work these ideas into our existing Environment
framework.
Refactored Game Play!
Last week, we implemented Q-learning for our Blackjack game. We found the solution looked a lot like Frozen Lake for the most part. So we created a new class EnvironmentMonad
to combine the steps these games have in common. This week, we'll see a full implementation of that class. Our goal is a couple generic gameLoop
functions we can use for different modes of our game.
As always, the code for this article is on our Github repository! You'll mainly want to explore any of the source files with Environment
in their name.
Expanding our Environment
Last time, we put together a basic idea of what a generic environment could look like. We made a couple separate "sub-classes" as well, for rendering and learning.
class (Monad m) => EnvironmentMonad m where
  type Observation m :: *
  type Action m :: *
  resetEnv :: m (Observation m)
  stepEnv :: (Action m) -> m (Observation m, Reward, Bool)
class (MonadIO m, EnvironmentMonad m) => RenderableEnvironment m where
  renderEnv :: m ()
class (EnvironmentMonad m) => LearningEnvironment m where
  learnEnv ::
    (Observation m) -> (Observation m) -> Reward -> (Action m) -> m ()
There are still a couple extra pieces we can add that will make these classes more complete. One thing we're missing here is a concrete expression of our state. This makes it difficult to run our environments from normal code. So let's add a new type to the family for our "Environment" type, as well as a function to "run" that environment. We'll also want a generic way to get the current observation.
class (Monad m) => EnvironmentMonad m where
  type Observation m :: *
  type Action m :: *
  type EnvironmentState m :: *
  runEnv :: (EnvironmentState m) -> m a -> IO a
  currentObservation :: m (Observation m)
  resetEnv :: m (Observation m)
  stepEnv :: (Action m) -> m (Observation m, Reward, Bool)
Forcing runEnv
to use IO
is more restrictive than we'd like. In the future we might explore how to get our environment to wrap a monad parameter to fix this.
We can also add a couple items to our LearningEnvironment
for the exploration rate. This way, we don't need to do anything concrete to affect the learning process. We'll also make the function for choosing an action a specific part of the environment.
class (EnvironmentMonad m) => LearningEnvironment m where
  learnEnv ::
    (Observation m) -> (Observation m) -> Reward -> (Action m) -> m ()
  chooseActionBrain :: m (Action m)
  explorationRate :: m Double
  reduceExploration :: Double -> Double -> m ()
Game Loops
In previous iterations, we had gameLoop
functions for each of our different environments. We can now write these in a totally generic way! Here's a simple loop that plays the game once and produces a result:
gameLoop :: (EnvironmentMonad m) =>
  m (Action m) -> m (Observation m, Reward)
gameLoop chooseAction = do
  newAction <- chooseAction
  (newObs, reward, done) <- stepEnv newAction
  if done
    then return (newObs, reward)
    else gameLoop chooseAction
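For readers coming from the Python side of this series, here's a rough sketch of the same generic loop. The `CountdownEnv` class is a hypothetical toy environment invented for this example. The only thing the loop assumes is a `step` method that returns an observation, a reward, and a "done" flag.

```python
class CountdownEnv:
    """Hypothetical toy environment: the episode ends after `start` steps."""

    def __init__(self, start=3):
        self.start = start
        self.remaining = start

    def reset(self):
        self.remaining = self.start
        return self.remaining

    def step(self, action):
        # The action is ignored here; a real game would use it.
        self.remaining -= 1
        done = self.remaining == 0
        reward = 1.0 if done else 0.0
        return self.remaining, reward, done


def game_loop(env, choose_action):
    """Step the environment until it reports done.

    Returns the final (observation, reward) pair, like the Haskell version."""
    while True:
        action = choose_action()
        obs, reward, done = env.step(action)
        if done:
            return obs, reward


env = CountdownEnv(start=3)
env.reset()
final_obs, final_reward = game_loop(env, choose_action=lambda: 0)
```

The loop never looks inside the environment. Any object with a matching `step` method would work, which is the same flexibility the type class gives us.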
If we want to render the game between moves, we add a single renderEnv
call before selecting the move. We also need an extra IO
constraint and to render it before returning the final result.
gameRenderLoop :: (RenderableEnvironment m) =>
  m (Action m) -> m (Observation m, Reward)
gameRenderLoop chooseAction = do
  renderEnv
  newAction <- chooseAction
  (newObs, reward, done) <- stepEnv newAction
  if done
    then renderEnv >> return (newObs, reward)
    else gameRenderLoop chooseAction
Finally, there are a couple different loops we can write for a learning environment. We can have a generic loop for one iteration of the game. Notice how we rely on the class function chooseActionBrain
. This means we don't need such a function as a parameter.
gameLearningLoop :: (LearningEnvironment m) =>
  m (Observation m, Reward)
gameLearningLoop = do
  oldObs <- currentObservation
  newAction <- chooseActionBrain
  (newObs, reward, done) <- stepEnv newAction
  learnEnv oldObs newObs reward newAction
  if done
    then return (newObs, reward)
    else gameLearningLoop
Then we can make another loop that runs many learning iterations. We reduce the exploration rate at a reasonable interval.
gameLearningIterations :: (LearningEnvironment m) => m [Reward]
gameLearningIterations = forM [1..numEpisodes] $ \i -> do
  resetEnv
  when (i `mod` 100 == 99) $ do
    reduceExploration decayRate minEpsilon
  (_, reward) <- gameLearningLoop
  return reward
  where
    numEpisodes = 10000
    decayRate = 0.9
    minEpsilon = 0.01
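To get a feel for how quickly this schedule shrinks the exploration rate, here's a small Python sketch that replays it without playing any games. The function names are made up for this example, and the starting rate of 1.0 is an assumption. The episode count, decay rate, and floor match the constants above.

```python
def reduce_exploration(epsilon, decay_rate, min_epsilon):
    """Multiply the rate by the decay factor, but never go below the floor."""
    return max(min_epsilon, epsilon * decay_rate)


def epsilon_schedule(start, num_episodes=10000, decay_rate=0.9, min_epsilon=0.01):
    """Return the exploration rate in effect for each episode (1-indexed)."""
    epsilon = start
    history = []
    for i in range(1, num_episodes + 1):
        if i % 100 == 99:
            epsilon = reduce_exploration(epsilon, decay_rate, min_epsilon)
        history.append(epsilon)
    return history


history = epsilon_schedule(start=1.0)
```

Under these assumptions, the rate hits the 0.01 floor on the 44th reduction, at episode 4,399. So a bit more than half of the run plays almost entirely greedily.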
Concrete Implementations
Now we want to see how we actually implement these classes for our types. We'll show the examples for FrozenLake
but it's an identical process for Blackjack
. We start by defining the monad type as a wrapper over our existing state.
newtype FrozenLake a = FrozenLake (StateT FrozenLakeEnvironment IO a)
deriving (Functor, Applicative, Monad)
We'll want to make a State
instance for our monads over the environment type. This will make it easier to port over our existing code. We'll also need a MonadIO
instance to help with rendering.
instance (MonadState FrozenLakeEnvironment) FrozenLake where
  get = FrozenLake get
  put fle = FrozenLake $ put fle
instance MonadIO FrozenLake where
  liftIO act = FrozenLake (liftIO act)
Then we want to change our function signatures to live in the desired monad. We can pretty much leave the functions themselves untouched.
resetFrozenLake :: FrozenLake FrozenLakeObservation
stepFrozenLake ::
FrozenLakeAction -> FrozenLake (FrozenLakeObservation, Reward, Bool)
renderFrozenLake :: FrozenLake ()
Finally, we make the actual instance for the class. The only thing we haven't defined yet is the runEnv
function. But this is a simple wrapper for evalStateT
.
instance EnvironmentMonad FrozenLake where
  type (Observation FrozenLake) = FrozenLakeObservation
  type (Action FrozenLake) = FrozenLakeAction
  type (EnvironmentState FrozenLake) = FrozenLakeEnvironment
  runEnv env (FrozenLake action) = evalStateT action env
  currentObservation = FrozenLake (currentObs <$> get)
  resetEnv = resetFrozenLake
  stepEnv = stepFrozenLake
instance RenderableEnvironment FrozenLake where
  renderEnv = renderFrozenLake
There's a bit more we could do. We could now separate the "brain" portions of the environment without any issues. We wouldn't need to keep the Q-Table and the exploration rate in the state. This would improve our encapsulation. We could also make our underlying monads more generic.
Playing the Game
Now, playing our game is simple! We get our basic environment, reset it, and call our loop function! This code will let us play one iteration of Frozen Lake, using our own input:
main :: IO ()
main = do
  (env :: FrozenLakeEnvironment) <- basicEnv
  _ <- runEnv env action
  putStrLn "Done!"
  where
    action = do
      resetEnv
      (gameRenderLoop chooseActionUser
        :: FrozenLake (FrozenLakeObservation, Reward))
Once again, we can make this code work for Blackjack with a simple name substitution.
We can also make this work with our Q-learning code as well. We start with a simple instance for LearningEnvironment
.
instance LearningEnvironment FrozenLake where
  learnEnv = learnQTable
  chooseActionBrain = chooseActionQTable
  explorationRate = flExplorationRate <$> get
  reduceExploration decayRate minEpsilon = do
    fle <- get
    let e = flExplorationRate fle
    let newE = max minEpsilon (e * decayRate)
    put $ fle { flExplorationRate = newE }
And now we use gameLearningIterations
instead of gameRenderLoop
!
main :: IO ()
main = do
  (env :: FrozenLakeEnvironment) <- basicEnv
  _ <- runEnv env action
  putStrLn "Done!"
  where
    action = do
      resetEnv
      (gameLearningIterations :: FrozenLake [Reward])
Conclusion
We're still pulling in two "extra" pieces besides the environment class itself. We still have specific implementations for basicEnv
and action choosing. We could try to abstract these behind the class as well. There would be generic functions for choosing the action as a human and choosing at random. This would force us to make the action space more general as well.
But for now, it's time to explore some more interesting learning algorithms. For our current Q-learning approach, we make a table with an entry for every possible game state. This doesn't scale to games with large or continuous observation spaces! Next week, we'll see how TensorFlow allows us to learn a Q function instead of a direct table.
We'll start in Python, but soon enough we'll be using TensorFlow in Haskell. Take a look at our guide for help getting everything installed!
Generalizing Our Environments
In our previous episode, we used Q-Learning to find a solution for the Frozen Lake scenario. We also have a Blackjack game that shares a lot of core ideas with Frozen Lake.
So in this part, we're going to start by applying our Q-Learning solution to the Blackjack game. This will highlight the similarities in the code between the two games. But we'll also see a few differences. The similarities will lead us to create a typeclass for our environment concept. Each "difference" in the two systems will suggest an expression that must be part of the class. Let's explore the implications of this.
Adding to the Environment
Once again, we will need to express our Q-table and the exploration rate as part of the environment. But this time, the index of our Q-Table will need to be a bit more complex. Remember our observation now has three different parts: the user's score, whether the player has an ace, and the dealer's show-card. We can turn each of these into a Word
, and combine them with the action itself. This gives us an index with four Word
values.
We want to populate this array with bounds to match the highest value in each of those fields.
data BlackjackEnvironment = BlackjackEnvironment
  { ...
  , qTable :: A.Array (Word, Word, Word, Word) Double
  , explorationRate :: Double
  } deriving (Show)
basicEnv :: IO BlackjackEnvironment
basicEnv = do
  gen <- Rand.getStdGen
  let (d, newGen) = shuffledDeck gen
  return $ BlackjackEnvironment
    ...
    (A.listArray ((0,0,0,0), (30, 1, 12, 1)) (repeat 0.0))
    1.0
While we're at it, let's create a function to turn an Observation/Action combination into an index.
makeQIndex :: BlackjackObservation -> BlackjackAction
  -> (Word, Word, Word, Word)
makeQIndex (BlackjackObservation pScore hasAce dealerCard) action =
  ( pScore
  , if hasAce then 1 else 0
  , fromIntegral . fromEnum $ dealerCard
  , fromIntegral . fromEnum $ action
  )
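Here's a rough Python rendering of the same indexing idea, in case the four-part index is easier to see outside of Haskell. The enum definitions and the dictionary-backed table are assumptions made for this sketch. The inclusive bounds `(30, 1, 12, 1)` come from the array above.

```python
from enum import IntEnum

# Hypothetical enums whose numbering mirrors the derived Enum instances.
Card = IntEnum(
    "Card",
    "TWO THREE FOUR FIVE SIX SEVEN EIGHT NINE TEN JACK QUEEN KING ACE",
    start=0,
)
Action = IntEnum("Action", "HIT STAND", start=0)


def make_q_index(player_score, has_ace, dealer_card, action):
    """Flatten an observation/action pair into a 4-part table index."""
    return (player_score, 1 if has_ace else 0, int(dealer_card), int(action))


def empty_q_table(bounds=(30, 1, 12, 1)):
    """Build a table with 0.0 at every index inside the inclusive bounds."""
    a, b, c, d = bounds
    return {
        (i, j, k, l): 0.0
        for i in range(a + 1)
        for j in range(b + 1)
        for k in range(c + 1)
        for l in range(d + 1)
    }


q_table = empty_q_table()
index = make_q_index(19, False, Card.SIX, Action.STAND)
```

With these bounds the table holds 31 * 2 * 13 * 2 = 1,612 entries. That's still tiny, which is why a plain table works for this game.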
With the help of this function, it's pretty easy to re-use most of our code from last time! The action choice function and the learning function look almost the same! So review last week's article (or the code on Github) for details.
Using the Same Game Loop
With our basic functions out of the way, let's now turn our attention to the game loop and running functions. For the game loop, we don't have anything too complicated. It's a step-by-step process.
- Retrieve the current observation
- Choose the next action
- Use this action to step the environment
- Use our "learning" function to update the Q-Table
- If we're done, return the reward. Otherwise recurse.
Here's what it looks like. Recall that we're taking our action choice function as an input. All our functions live in a similar monad, so this is pretty easy.
gameLoop :: (MonadIO m) =>
  StateT BlackjackEnvironment m BlackjackAction ->
  StateT BlackjackEnvironment m (BlackjackObservation, Double)
gameLoop chooseAction = do
  oldObs <- currentObservation <$> get
  newAction <- chooseAction
  (newObs, reward, done) <- stepEnv newAction
  learnQTable oldObs newObs reward newAction
  if done
    then do
      if reward > 0.0
        then liftIO $ putStrLn "Win"
        else liftIO $ putStrLn "Lose"
      return (newObs, reward)
    else gameLoop chooseAction
Now to produce our final output and run game iterations, we need a little wrapper code. We create (and reset) our initial environment. Then we pass it to an action that runs the game loop and reduces the exploration rate when necessary.
playGame :: IO ()
playGame = do
  env <- basicEnv
  env' <- execStateT resetEnv env
  void $ execStateT stateAction env'
  where
    numEpisodes = 10000
    decayRate = 1.0
    minEpsilon = 0.01
    stateAction :: StateT BlackjackEnvironment IO ()
    stateAction = do
      rewards <- forM [1..numEpisodes] $ \i -> do
        resetEnv
        when (i `mod` 100 == 99) $ do
          bje <- get
          let e = explorationRate bje
          let newE = max minEpsilon (e * decayRate)
          put $ bje { explorationRate = newE }
        (_, reward) <- gameLoop chooseActionQTable
        return reward
      lift $ print (sum rewards)
Now we can play our game! Even with learning, we'll still only get around 40% of the points available. Blackjack is a tricky, luck-based game, so this isn't too surprising.
Constructing a Class
Now if you look very carefully at the above code, it should almost work for Frozen Lake as well! We'd only need to make a few adjustments to naming types. This tells us we have a general structure between our different games. And we can capture that structure with a class.
Let's look at the common elements between our environments. These are all functions we call from the game loop or runner:
- Resetting the environment
- Stepping the environment (with an action)
- Rendering the environment (if necessary)
- Applying some learning method on the new data
- Diminishing the exploration rate
So our first attempt at this class might look like this, looking only at the most important fields:
class Environment e where
  resetEnv :: (Monad m) => StateT e m Observation
  stepEnv :: (Monad m) => Action
    -> StateT e m (Observation, Double, Bool)
  renderEnv :: (MonadIO m) => StateT e m ()
  learnEnv :: (Monad m) =>
    Observation -> Observation -> Double -> Action -> StateT e m ()
instance Environment FrozenLakeEnvironment where
  ...
instance Environment BlackjackEnvironment where
  ...
We can make two clear observations about this class. First, we need to generalize the Observation
and Action
types! These are different in our two games and this isn't reflected above. Second, we're forcing ourselves to use the State
monad over our environment. This isn't necessarily wise. It might force us to add extra fields to the environment type that don't belong there.
The solution to the first issue is to make this class a type family! Then we can associate the proper data types for observations and actions. The solution to the second issue is that our class should be over a monad instead of the environment itself.
Remember, a monad provides the context in which a computation takes place. So in our case, our game, with all its stepping and learning, is that context!
Doing this gives us more flexibility for figuring out what data should live in which types. It makes it easier to separate the game's internal state from auxiliary state, like the exploration rate.
Here's our second try, with associated types and a monad.
newtype Reward = Reward Double
class (MonadIO m) => EnvironmentMonad m where
  type Observation m :: *
  type Action m :: *
  resetEnv :: m (Observation m)
  currentObservation :: m (Observation m)
  stepEnv :: (Action m) -> m (Observation m, Reward, Bool)
  renderEnv :: m ()
  learnEnv ::
    (Observation m) -> (Observation m) ->
    Reward -> (Action m) -> m ()
  explorationRate :: m Double
  reduceExploration :: Double -> Double -> m ()
There are a couple undesirable parts of this. Our monad has to be IO
to account for rendering. But it's possible for us to play the game without needing to render. In fact, it's also possible for us to play the game without learning!
So we can separate this into more typeclasses! We'll have two "subclasses" of our Environment
. We'll make a separate class for rendering. This will be the only class that needs an IO
constraint. Then we'll have a class for learning functionality. This will allow us to "run" the game in different contexts and limit the reach of these effects.
newtype Reward = Reward Double
class (Monad m) => EnvironmentMonad m where
  type Observation m :: *
  type Action m :: *
  currentObservation :: m (Observation m)
  resetEnv :: m (Observation m)
  stepEnv :: (Action m) -> m (Observation m, Reward, Bool)
class (MonadIO m, EnvironmentMonad m) =>
  RenderableEnvironment m where
  renderEnv :: m ()
class (EnvironmentMonad m) => LearningEnvironment m where
  learnEnv ::
    (Observation m) -> (Observation m) ->
    Reward -> (Action m) -> m ()
  explorationRate :: m Double
  reduceExploration :: Double -> Double -> m ()
Conclusion
Next week we'll explore how to implement these classes for our different games! We'll end up with a totally generic function for playing the game. We'll have a version with learning and a version without!
The next step after this will be to attach more sophisticated learning mechanisms. Soon, we'll explore how to expand our Q-Learning beyond simple discrete states. The way to do this is to use tensors! So in a couple weeks, we'll explore how to use TensorFlow to construct a function for Q-Learning. To get ready, download our Haskell TensorFlow Guide!
Frozen Lake with Q-Learning!
In the last few weeks, we've written two simple games in Haskell: Frozen Lake and Blackjack. These games are both toy examples from the Open AI Gym. Now that we've written the games, it's time to explore more advanced ways to write agents for them.
In this article, we'll explore the concept of Q-Learning. We've talked about this idea on the MMH blog before. But now we'll see it in action in a simpler context than we did before. We'll write a little bit of Python code, following some examples for Frozen Lake. Then we'll try to implement the same ideas in Haskell. Along the way, we'll see more patterns emerge about our games' interfaces.
We won't be using TensorFlow in this article. But we'll soon explore ways to augment our agent's capabilities with this library! To learn about Haskell and TensorFlow, download our TensorFlow guide!
Making a Q-Table
Let's start by taking a look at this basic Python implementation of Q-Learning for Frozen Lake. This will show us the basic ideas of Q-Learning. We start out by defining a few global parameters, as well as Q
, a variable that will hold a table of values.
import gym
import numpy
epsilon = 0.9
min_epsilon = 0.01
decay_rate = 0.9
total_episodes = 10000
max_steps = 100
learning_rate = 0.81
gamma = 0.96
env = gym.make('FrozenLake-v0')
Q = numpy.zeros((env.observation_space.n, env.action_space.n))
Recall that our environment has an action space and an observation space. For this basic version of the Frozen Lake game, an observation is a discrete integer value from 0 to 15. This represents the location our character is on. Then the action space is an integer from 0 to 3, for each of the four directions we can move. So our "Q-table" will be an array with 16 rows and 4 columns.
How does this help us choose our move? Well, each cell in this table has a score. This score tells us how good a particular move is for a particular observation state. So we could define a choose_action
function in a simple way like so:
def choose_action(observation):
    return numpy.argmax(Q[observation, :])
This will look at the different values in the row for this observation, and choose the index of the highest value. So if the "0" value in this row is the highest, we'll return 0, indicating we should move left. If the second value is highest, we'll return 1, indicating a move down.
But we don't want to choose our moves deterministically! Our Q-Table starts out in the "untrained" state. And we need to actually find the goal at least once to start back-propagating rewards into our maze. This means we need to build some kind of exploration into our system. So each turn, we can make a random move with probability epsilon
.
def choose_action(observation):
    action = 0
    if numpy.random.uniform(0, 1) < epsilon:
        action = env.action_space.sample()
    else:
        action = numpy.argmax(Q[observation, :])
    return action
As we learn more, we'll diminish the exploration probability. We'll see this below!
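If you'd like to poke at this choice rule without installing Gym, here's a self-contained sketch that uses only the standard library. Passing in the Q-table row, the `epsilon` value, and a seeded `random.Random` generator are all choices made for this example, not part of the code above.

```python
import random


def choose_action(q_row, epsilon, rng):
    """Pick a random action with probability epsilon, else the best-scoring one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_row))
    # Like argmax: the index of the highest value (first one on ties).
    return max(range(len(q_row)), key=lambda a: q_row[a])


rng = random.Random(0)  # seeded so runs are repeatable
row = [0.1, 0.7, 0.3, 0.2]

greedy = choose_action(row, epsilon=0.0, rng=rng)
explored = [choose_action(row, epsilon=1.0, rng=rng) for _ in range(200)]
```

With `epsilon=0.0` we always get the greedy move, and with `epsilon=1.0` every move is random. Real training sits in between and slides toward the greedy end.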
Updating the Table
Now, we also want to be able to update our table. To do this, we'll write a function that follows the Q-learning rule. It will take two observations, the reward for the second observation, and the action we took to get there.
def learn(observation, observation2, reward, action):
    prediction = Q[observation, action]
    target = reward + gamma * numpy.max(Q[observation2, :])
    Q[observation, action] = (Q[observation, action] +
        learning_rate * (target - prediction))
For more details on what happens here, read our Q-Learning primer. But there's one general rule.
Suppose we move from Observation O1
to Observation O2
with action A
. We want the Q-table value for the pair (O1, A)
to be closer to the best value we can get from O2
. And we want to factor in the potential reward we can get by moving to O2
. Thus our goal square should have the reward of 1. And squares near it should have values close to this reward!
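A tiny worked example helps show how the reward spreads backwards. This sketch uses nested lists in place of the numpy table and runs two hand-picked updates. The particular squares and the action number are assumptions chosen for illustration. The learning rate and `gamma` match the parameters above.

```python
def learn(q, obs, obs2, reward, action, learning_rate=0.81, gamma=0.96):
    """Apply one Q-learning update to the table in place."""
    prediction = q[obs][action]
    target = reward + gamma * max(q[obs2])
    q[obs][action] = prediction + learning_rate * (target - prediction)


# 16 observations, 4 actions, everything starts at zero.
q = [[0.0] * 4 for _ in range(16)]

# Hypothetical step from square 14 onto the goal (square 15), reward 1.
learn(q, obs=14, obs2=15, reward=1.0, action=2)

# Hypothetical later step from square 13 to square 14, no reward.
learn(q, obs=13, obs2=14, reward=0.0, action=2)
```

After the first update, `q[14][2]` is 0.81. After the second, `q[13][2]` is 0.81 * 0.96 * 0.81, roughly 0.63. The square two steps away picked up value even though the move itself earned nothing.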
Playing the Game
Playing the game now is straightforward, following the examples we've done before. We'll have a certain number of episodes. Within each episode, we make our move, and use the reward to "learn" for our Q-table.
for episode in range(total_episodes):
    obs = env.reset()
    t = 0
    if episode % 100 == 99:
        epsilon *= decay_rate
        epsilon = max(epsilon, min_epsilon)
    while t < max_steps:
        action = choose_action(obs)
        obs2, reward, done, info = env.step(action)
        learn(obs, obs2, reward, action)
        obs = obs2
        t += 1
        if done:
            if reward > 0.0:
                print("Win")
            else:
                print("Lose")
            break
Notice also how we drop the exploration rate epsilon
every 100 episodes or so. We can run this, and we'll observe that we lose a lot at first. But by the end we're winning more often than not! Once the final episode finishes, it's a good idea to save the Q-table in some sensible way.
Haskell: Adding a Q-Table
To translate this into Haskell, we first need to account for our new pieces of state. Let's extend our environment type to include two more fields. One will be for our Q-table. We'll use an array for this as well, as this gives convenient accessing and updating syntax. The other will be the current exploration rate:
data FrozenLakeEnvironment = FrozenLakeEnvironment
  { ...
  , qTable :: A.Array (Word, Word) Double
  , explorationRate :: Double
  }
Now we'll want to write two primary functions. First, we'll want to choose our action using the Q-Table. Second, we want to be able to update the Q-Table so we can "learn" a good path.
Both of these will use this helper function. It takes an Observation
and the current Q-Table and produces the best score we can get from that location. It also provides us the action index. Note the use of a tuple section to produce indices
.
maxScore ::
  Observation ->
  A.Array (Word, Word) Double ->
  (Double, (Word, Word))
maxScore obs table = maximum valuesAndIndices
  where
    indices = (obs, ) <$> [0..3]
    valuesAndIndices = (\i -> (table A.! i, i)) <$> indices
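One subtle point about taking the maximum over `(value, index)` tuples is how ties get broken. Tuples compare their first component and then their second, so among equal scores the largest index wins. `numpy.argmax`, by contrast, returns the first maximal index. The Python sketch below imitates both behaviors with the standard library. The function names and the dictionary table are assumptions for this example.

```python
def max_score(obs, table, num_actions=4):
    """Mirror the tuple-based maximum: best (value, index) pair for a row."""
    return max((table[(obs, a)], (obs, a)) for a in range(num_actions))


def first_argmax(row):
    """Index of the first maximal value, which is how numpy.argmax breaks ties."""
    return row.index(max(row))


# An untrained table: every score is zero, so every action ties.
table = {(o, a): 0.0 for o in range(16) for a in range(4)}

value, (_, best_action) = max_score(0, table)
numpy_style_action = first_argmax([table[(0, a)] for a in range(4)])
```

On an untrained table the tuple version picks action 3 while the argmax version picks action 0. It rarely matters once exploration kicks in, but it can explain small differences between the two implementations early on.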
Using the Q-Table
Now let's see how we produce our actions using this table. As with most of our state functions, we'll start by retrieving the environment. Then we'll get our first roll to see if this is an exploration turn or not.
chooseActionQTable ::
  (MonadState FrozenLakeEnvironment m) => m Action
chooseActionQTable = do
  fle <- get
  let (exploreRoll, gen') = Rand.randomR (0.0, 1.0) (randomGenerator fle)
  if exploreRoll < explorationRate fle
    ...
If we're exploring, we do another random roll to pick an action and replace the generator. Otherwise we'll get the best scoring move and derive the Action
from the returned index. In both cases, we use toEnum
to turn the number into a proper Action
.
chooseActionQTable ::
  (MonadState FrozenLakeEnvironment m) => m Action
chooseActionQTable = do
  fle <- get
  let (exploreRoll, gen') = Rand.randomR (0.0, 1.0) (randomGenerator fle)
  if exploreRoll < explorationRate fle
    then do
      let (actionRoll, gen'') = Rand.randomR (0, 3) gen'
      put $ fle { randomGenerator = gen'' }
      return (toEnum actionRoll)
    else do
      let maxIndex = snd $ snd $
            maxScore (currentObservation fle) (qTable fle)
      put $ fle {randomGenerator = gen' }
      return (toEnum (fromIntegral maxIndex))
The last big step is to write our learning function. Remember this takes two observations, a reward, and an action. We start by getting our predicted value for the original observation. That is, what score did we expect when we made this move?
learnQTable :: (MonadState FrozenLakeEnvironment m) =>
  Observation -> Observation -> Double -> Action -> m ()
learnQTable obs1 obs2 reward action = do
  fle <- get
  let q = qTable fle
      actionIndex = fromIntegral . fromEnum $ action
      prediction = q A.! (obs1, actionIndex)
  ...
Now we specify our target
. This combines the reward (if any) and the greatest score we can get from our new observed state. We use these values to get a newValue
, which we put into the Q-Table at the original index. Then we put
the new table into our state.
learnQTable :: (MonadState FrozenLakeEnvironment m) =>
  Observation -> Observation -> Double -> Action -> m ()
learnQTable obs1 obs2 reward action = do
  fle <- get
  let q = qTable fle
      actionIndex = fromIntegral . fromEnum $ action
      prediction = q A.! (obs1, actionIndex)
      target = reward + gamma * (fst $ maxScore obs2 q)
      newValue = prediction + learningRate * (target - prediction)
      newQ = q A.// [((obs1, actionIndex), newValue)]
  put $ fle { qTable = newQ }
  where
    gamma = 0.96
    learningRate = 0.81
And just like that, we're pretty much done! We can slide these new functions right into our existing functions!
Conclusion
The rest of the code is straightforward enough. We make a couple tweaks as necessary to our gameLoop
so that it actually calls our training function. Then we just update the exploration rate at appropriate intervals. Take a look at our code on Github for more details! This week's code is in FrozenLake2.hs
.
We've now got an agent that can play Frozen Lake coherently using Q-Learning! Next time, we'll try to adapt this agent for Blackjack as well. We'll see the similarities between the two games. Then we'll start formulating some ideas to combine the approaches.
Blackjack: Following the Patterns
For a couple weeks now, we've been exploring the basics of Open AI Gym. The Frozen Lake example has been our basic tool so far, and we've now written it in Haskell. We'd like to start training agents for this game soon. But first, we want to make sure we're set up to generalize our idea of an environment.
So this week, we're going to make another small example game. This time, we'll play Blackjack. This will give us an example of an environment that needs a more complex observation state. When we're done with this example, we'll be able to compare our two examples. The end goal is to be able to use the same code to train an algorithm for either of them.
If you want to dive into machine learning, you'll need to understand TensorFlow first! Read this guide to learn how to use TensorFlow with Haskell!
Basic Rules
If you don't know the basic rules of casino blackjack, take a look here. Essentially, we have a deck of cards, and each card has a value. We want to get as high a score as we can without exceeding 21 (a "bust"). Each turn, we want to either "hit" and add another card to our hand, or "stand" and take the value we have.
After we get all our cards, the dealer must then draw cards under specific rules. The dealer must "hit" until their score is 17 or higher, and then "stand". If the dealer busts or our score beats the dealer, we win. If the scores are the same it's a "push".
Here's a basic Card
type we'll work with to represent the card values, as well as their scores.
data Card =
  Two | Three | Four | Five |
  Six | Seven | Eight | Nine |
  Ten | Jack | Queen | King | Ace
  deriving (Show, Eq, Enum)
cardScore :: Card -> Word
cardScore Two = 2
cardScore Three = 3
cardScore Four = 4
cardScore Five = 5
cardScore Six = 6
cardScore Seven = 7
cardScore Eight = 8
cardScore Nine = 9
cardScore Ten = 10
cardScore Jack = 10
cardScore Queen = 10
cardScore King = 10
cardScore Ace = 1
The Ace
can count as 1 or 11. We account for this in our scoring functions:
-- Returns the base sum, as well as a boolean if we have
-- a "usable" Ace.
baseScore :: [Card] -> (Word, Bool)
baseScore cards = (score, score <= 11 && Ace `elem` cards)
  where
    score = sum (cardScore <$> cards)
scoreHand :: [Card] -> Word
scoreHand cards = if hasUsableAce then score + 10 else score
  where
    (score, hasUsableAce) = baseScore cards
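To check the "usable Ace" rule on a few hands, here's the same scoring logic as a small Python sketch. Representing cards as strings in a dictionary is an assumption made for this example. The point values and the `score <= 11` rule come straight from the functions above.

```python
CARD_SCORES = {
    "Two": 2, "Three": 3, "Four": 4, "Five": 5, "Six": 6, "Seven": 7,
    "Eight": 8, "Nine": 9, "Ten": 10, "Jack": 10, "Queen": 10, "King": 10,
    "Ace": 1,
}


def base_score(cards):
    """Sum with every Ace counted as 1, plus whether one Ace can count as 11."""
    score = sum(CARD_SCORES[card] for card in cards)
    return score, score <= 11 and "Ace" in cards


def score_hand(cards):
    """Final hand value: add 10 when an Ace can be promoted without busting."""
    score, has_usable_ace = base_score(cards)
    return score + 10 if has_usable_ace else score


blackjack = score_hand(["Ace", "King"])         # 1 + 10, Ace promoted
hard_hand = score_hand(["Ace", "Nine", "Five"])  # 15, Ace stays at 1
two_aces = score_hand(["Ace", "Ace"])            # only one Ace is promoted
```

Adding a flat 10 means at most one Ace ever counts as 11, which is all we need since two promoted Aces would already be a bust.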
Core Environment Types
As in Frozen Lake, we need to define types for our environment. The "action" type is straightforward, giving only two options for "hit" and "stand":
data BlackjackAction = Hit | Stand
deriving (Show, Eq, Enum)
Our observation is more complex than in Frozen Lake. We have more information that can guide us than just knowing our location. We'll boil it down to three elements. First, we need to know our own score. Second, we need to know if we have an Ace. This isn't clear from the score, and it can give us more options. Last, we need to know what card the dealer is showing.
data BlackjackObservation = BlackjackObservation
  { playerScore :: Word
  , playerHasAce :: Bool
  , dealerCardShowing :: Card
  } deriving (Show)
Now for our environment, we'll once again store the "current observation" as one of its fields.
data BlackjackEnvironment = BlackjackEnvironment
{ currentObservation :: BlackjackObservation
...
}
The main fields are about the cards in play. We'll have a list of cards for our own hand. Then we'll have the main deck to draw from. The dealer's cards will be a 3-tuple. The first is the "showing" card. The second is the hidden card. And the third is a list for extra cards the dealer draws later.
data BlackjackEnvironment = BlackjackEnvironment
{ currentObservation :: BlackjackObservation
, playerHand :: [Card]
, deck :: [Card]
, dealerHand :: (Card, Card, [Card])
...
}
The last pieces of this will be a boolean for whether the player has "stood", and a random generator. The boolean helps us render the game, and the generator helps us reset and shuffle without using IO
.
data BlackjackEnvironment = BlackjackEnvironment
  { currentObservation :: BlackjackObservation
  , playerHand :: [Card]
  , deck :: [Card]
  , dealerHand :: (Card, Card, [Card])
  , randomGenerator :: Rand.StdGen
  , playerHasStood :: Bool
  } deriving (Show)
Now we can use these to write our main game functions. As in Frozen Lake, we'll want functions to render the environment and reset it. We won't go over those in this article. But we will focus on the core step
function.
Playing the Game
Our step
function starts out simply enough. We retrieve our environment and analyze the action we get.
stepEnv :: (Monad m) => BlackjackAction ->
  StateT BlackjackEnvironment m (BlackjackObservation, Double, Bool)
stepEnv action = do
  bje <- get
  case action of
    Stand -> ...
    Hit -> ...
Below, we'll write a function to play the dealer's hand. So for the Stand
branch, we'll update the state variable for the player standing, and call that helper.
stepEnv action = do
  bje <- get
  case action of
    Stand -> do
      put $ bje { playerHasStood = True }
      playOutDealerHand
    Hit -> ...
When we hit, we need to determine the top card in the deck. We'll add this to our hand to get the new player score. All this information goes into our new observation, and the new state of the game.
stepEnv action = do
  bje <- get
  case action of
    Stand -> ...
    Hit -> do
      let (topCard : remainingDeck) = deck bje
          pHand = playerHand bje
          currentObs = currentObservation bje
          newPlayerHand = topCard : pHand
          newScore = scoreHand newPlayerHand
          newObservation = currentObs
            { playerScore = newScore
            , playerHasAce = playerHasAce currentObs ||
                             topCard == Ace}
      put $ bje { currentObservation = newObservation
                , playerHand = newPlayerHand
                , deck = remainingDeck }
      ...
Now we need to analyze the player's score. If it's greater than 21, we've busted. We return a reward of 0.0 and we're done. If it's exactly 21, we'll treat that like a "stand" and play out the dealer. Otherwise, we'll continue by returning False
.
stepEnv action = do
  bje <- get
  case action of
    Stand -> ...
    Hit -> do
      ...
      if newScore > 21
        then return (newObservation, 0.0, True)
        else if newScore == 21
          then playOutDealerHand
          else return (newObservation, 0.0, False)
Playing out the Dealer
To wrap up the game, we need to give cards to the dealer until their score is high enough. So let's start by getting the environment and scoring the dealer's current hand.
playOutDealerHand :: (Monad m) =>
StateT BlackjackEnvironment m (BlackjackObservation, Double, Bool)
playOutDealerHand = do
bje <- get
let (showCard, hiddenCard, restCards) = dealerHand bje
currentDealerScore = scoreHand (showCard : hiddenCard : restCards)
If the dealer's score is less than 17, we can draw the top card, add it to their hand, and recurse.
playOutDealerHand :: (Monad m) => StateT BlackjackEnvironment m (BlackjackObservation, Double, Bool)
playOutDealerHand = do
...
if currentDealerScore < 17
then do
let (topCard : remainingDeck) = deck bje
put $ bje { dealerHand =
(showCard, hiddenCard, topCard : restCards)
, deck = remainingDeck}
playOutDealerHand
else ...
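The drawing rule can also be sketched as a pure function over plain lists, which is easier to test than the stateful version. This is only an illustration under simplifying assumptions: cards are bare point values, scoreValues is a stand-in for scoreHand with no ace handling, and dealerDraw is a hypothetical name. Unlike the pattern match above, it also stops if the deck runs out.

```haskell
-- Simplified stand-in for scoreHand: cards are plain point values,
-- so there is no special handling for aces here.
scoreValues :: [Int] -> Int
scoreValues = sum

-- Draw from the front of the deck until the hand scores at least 17.
-- Returns the final hand and the remaining deck.
dealerDraw :: [Int] -> [Int] -> ([Int], [Int])
dealerDraw hand deck
  | scoreValues hand >= 17 = (hand, deck)
  | otherwise = case deck of
      []               -> (hand, [])  -- deck exhausted: stop drawing
      (topCard : rest) -> dealerDraw (topCard : hand) rest
```

For instance, a dealer holding 6 and 9 with an 8 on top of the deck draws once and stops at 23.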
Now all that's left is analyzing the end conditions. We'll score the player's hand and compare it to the dealer's. If the dealer has busted, or the player has the better score, we'll give a reward of 1.0. If they're the same, the reward is 0.5. Otherwise, the player loses. In all cases, we return the current observation and True
as our "done" variable.
playOutDealerHand :: (Monad m) => StateT BlackjackEnvironment m (BlackjackObservation, Double, Bool)
playOutDealerHand = do
bje <- get
let (showCard, hiddenCard, restCards) = dealerHand bje
currentDealerScore = scoreHand
(showCard : hiddenCard : restCards)
if currentDealerScore < 17
then ...
else do
let playerScore = scoreHand (playerHand bje)
currentObs = currentObservation bje
if playerScore > currentDealerScore || currentDealerScore > 21
then return (currentObs, 1.0, True)
else if playerScore == currentDealerScore
then return (currentObs, 0.5, True)
else return (currentObs, 0.0, True)
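The reward rule on its own is small enough to isolate as a pure function. The sketch below mirrors the comparison above; finalReward is a hypothetical helper name, and the scores are assumed to be plain Int values.

```haskell
-- Hypothetical helper mirroring the end-of-game comparison.
-- Arguments: the player's score, then the dealer's score.
finalReward :: Int -> Int -> Double
finalReward playerScore dealerScore
  | dealerScore > 21 || playerScore > dealerScore = 1.0  -- win
  | playerScore == dealerScore                    = 0.5  -- tie
  | otherwise                                     = 0.0  -- loss
```

Keeping this logic separate from the State code would let us check the three outcomes without building a whole environment.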
Odds and Ends
We'll also need code for running a loop and playing the game. That code, though, looks very similar to what we used for Frozen Lake. This is a promising sign for our hopes to generalize this with a type class. Here's a sample playthrough of the game. As inputs, 0
means "hit" and 1
means "stand".
So in this first game, we start with a King and 9, and see the dealer has a 6 showing. We "stand", and the dealer busts.
6 X
K 9
19 # Our current score
1 # Stand command
1.0 # Reward
Episode Finished
6 9 8 # Dealer's final hand
23 # Dealer's final (busted) score
K 9
19
In this next example, we try to hit on 13, since the dealer has an Ace. We bust, and lose the game.
A X
3 J
13
0
0.0
Episode Finished
A X
K 3 J
23
Conclusion
Of course, there are a few ways we could make this more complicated. We could do iterated blackjack to allow card-counting. Or we could add advanced moves like splitting and doubling down. But that's not necessary for our purposes. The main point is that we have two fully functional games we can work with!
Next time, we'll start digging into the machine learning process. We'll see what techniques we can use with the Open AI Gym in Python and start translating those to Haskell.
We left out quite a bit of code in this example, particularly around setup. Take a look at our Github repository to see all the details!
Frozen Lake in Haskell
Last time on MMH, we began our investigation into Open AI Gym. We started by using the Frozen Lake toy example to learn about environments. An environment is a basic wrapper that has a specific API for manipulating the game.
Last week's work was mostly in Python. But this week, we're going to do a deep dive into Haskell and consider how to write the Frozen Lake example. We'll see all the crucial functions from the Environment API as well as how to play the game. Take a look at our Github repository to see any extra details about this code!
This process will culminate with training agents to complete these games with machine learning. This will involve TensorFlow. So if you haven't already, download our Haskell Tensor Flow Guide. It will teach you how to get the framework up and running on your machine.
Core Types
In the previous part, we started defining our environment with generic values. For example, we included the action space and observation space. For now, we're actually going to make things more specific to the Frozen Lake problem. This will keep our example much simpler for now. In the coming weeks, we'll start examining how to generalize the idea of an environment and spaces.
We need to start with the core types of our application. We'll begin with a TileType
for our board, as well as observations and actions.
data TileType =
Start |
Goal |
Frozen |
Hole
deriving (Show, Eq)
type Observation = Word
data Action =
MoveLeft |
MoveDown |
MoveRight |
MoveUp
deriving (Show, Eq, Enum)
As in Python, each observation will be a single number indicating where we are on the board. We'll have four different actions. The Enum
instance will help us convert between these constructors and numbers.
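Here is a small sketch of that conversion. The Bounded instance and the two helper names are additions for illustration only; they are not part of the type defined above.

```haskell
data Action
  = MoveLeft
  | MoveDown
  | MoveRight
  | MoveUp
  deriving (Show, Eq, Enum, Bounded)

-- Constructor to number: its position in the declaration, counting from 0.
actionToInt :: Action -> Int
actionToInt = fromEnum

-- Number to constructor, returning Nothing outside 0-3.
-- (Calling toEnum directly on an out-of-range number is a runtime error.)
intToAction :: Int -> Maybe Action
intToAction n
  | n >= fromEnum (minBound :: Action) && n <= fromEnum (maxBound :: Action) =
      Just (toEnum n)
  | otherwise = Nothing
```

So MoveLeft corresponds to 0 and MoveUp to 3, matching the numbering used in the Python version.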
Now let's consider the different elements we actually need within the environment. The game's main information is the grid of tiles. We'll store this as an Array. The indices will be our observation values, and the elements will be the TileType. For convenience, we'll also store the dimensions of our grid:
data FrozenLakeEnvironment = FrozenLakeEnvironment
{ grid :: Array Word TileType
, dimens :: (Word, Word) -- Rows, Columns
...
}
We also need some more information. We need the current player location, an Observation. We'll want to know the previous action, for rendering purposes. The game also stores the chance of slipping each turn. The last piece of state we want is the random generator. Storing this within our environment lets us write our step function in a pure way, without IO.
data FrozenLakeEnvironment = FrozenLakeEnvironment
{ grid :: Array Word TileType
, dimens :: (Word, Word) -- Rows, Cols
, currentObservation :: Observation
, previousAction :: Maybe Action
, slipChance :: Double
, randomGenerator :: Rand.StdGen
}
API Functions
Now our environment needs its API functions. We had three main ones last time. These were reset, render, and step. Last week we wrote these to take the environment as an explicit parameter. But this time, we'll write them in the State monad. This will make it much easier to chain these actions together later. Let's start with reset, the simplest function. All it does is set the observation to 0 and clear any previous action.
resetEnv :: (Monad m) => StateT FrozenLakeEnvironment m Observation
resetEnv = do
let initialObservation = 0
fle <- get
put $ fle { currentObservation = initialObservation
, previousAction = Nothing }
return initialObservation
Rendering is a bit more complicated. When resetting, we can use any underlying monad. But to render, we'll insist that the monad allows IO, so we can print to console. First, we get our environment and pull some key values out of it. We want the current observation and each row of the grid.
renderEnv :: (MonadIO m) => StateT FrozenLakeEnvironment m ()
renderEnv = do
fle <- get
let currentObs = currentObservation fle
elements = A.assocs (grid fle)
numCols = fromIntegral . snd . dimens $ fle
rows = chunksOf numCols elements
...
We use chunksOf
with the number of columns to divide our grid into rows. Each element of each row-list is the pairing of the "index" with the tile type. We keep the index so we can compare it to the current observation. Now we'll write a helper to render each of these rows. We'll have another helper to print a character for each tile type. But we'll print X
for the current location.
renderEnv :: (MonadIO m) => StateT FrozenLakeEnvironment m ()
renderEnv = do
  ...
  where
    -- Runs in plain IO: renderEnv calls it from inside a single liftIO block.
    renderRow currentObs row = do
      forM_ row (\(idx, t) -> if idx == currentObs
        then putChar 'X'
        else putChar (tileToChar t))
      putChar '\n'
tileToChar :: TileType -> Char
...
Then we just need to print a line for the previous action, and render each row:
renderEnv :: (MonadIO m) => StateT FrozenLakeEnvironment m ()
renderEnv = do
fle <- get
let currentObs = currentObservation fle
elements = A.assocs (grid fle)
numCols = fromIntegral . snd . dimens $ fle
rows = chunksOf numCols elements
liftIO $ do
putStrLn $ case (previousAction fle) of
Nothing -> ""
Just a -> " " ++ show a
forM_ rows (renderRow currentObs)
where
renderRow = ...
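To see the row-splitting logic without any IO, we can render the grid to a list of strings instead. The sketch below is an illustration, not the code from the repository: chunks is a base-only stand-in for chunksOf, tileChar assumes the same S/F/H/G letters used to describe layouts, and renderLines is a hypothetical name.

```haskell
data TileType = Start | Goal | Frozen | Hole
  deriving (Show, Eq)

-- Assumed character mapping, matching the S/F/H/G layout letters.
tileChar :: TileType -> Char
tileChar Start  = 'S'
tileChar Goal   = 'G'
tileChar Frozen = 'F'
tileChar Hole   = 'H'

-- Base-only stand-in for chunksOf: split a list into pieces of length n.
chunks :: Int -> [a] -> [[a]]
chunks n xs
  | n <= 0 || null xs = []
  | otherwise         = let (row, rest) = splitAt n xs in row : chunks n rest

-- Render the grid as lines of text, marking the current position with 'X'.
renderLines :: Int -> Word -> [TileType] -> [String]
renderLines numCols currentObs tiles =
  [ [ if idx == currentObs then 'X' else tileChar t | (idx, t) <- row ]
  | row <- chunks numCols (zip [0 ..] tiles)
  ]
```

Printing the result with mapM_ putStrLn would then give the same kind of picture as renderEnv, minus the line for the previous action.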
Stepping
Now let's see how we update our environment! This will also be in our State
monad (without any IO constraint). It will return a 3-tuple with our new observation, a "reward", and a boolean for if we finished. Once again we start by gathering some useful values.
stepEnv :: (Monad m) => Action
  -> StateT FrozenLakeEnvironment m (Observation, Double, Bool)
stepEnv act = do
  fle <- get
  let currentObs = currentObservation fle
  let (slipRoll, gen') = Rand.randomR (0.0, 1.0) (randomGenerator fle)
  let allLegalMoves = legalMoves currentObs (dimens fle)
  let (randomMoveIndex, finalGen) =
        Rand.randomR (0, length allLegalMoves - 1) gen'
  ...
-- Get all the actions we can do, given the current observation
-- and the number of rows and columns
legalMoves :: Observation -> (Word, Word) -> [Action]
...
We now have two random values. The first is for our "slip roll". We can compare this with the game's slipChance
to determine if we try the player's move or a random move. If we need to do a random move, we'll use randomMoveIndex
to figure out which random move we'll do.
The only other check we need to make is whether the player's move is "legal". If it's not, we'll stand still. The applyMoveUnbounded
function tells us what the next Observation
should be for the move. For example, we add 1 for moving right, or subtract 1 for moving left.
stepEnv :: (Monad m) => Action
  -> StateT FrozenLakeEnvironment m (Observation, Double, Bool)
stepEnv act = do
  ...
  let newObservation = if slipRoll >= slipChance fle
        then if act `elem` allLegalMoves
          then applyMoveUnbounded
            act currentObs (snd . dimens $ fle)
          else currentObs
        else applyMoveUnbounded
          (allLegalMoves !! randomMoveIndex)
          currentObs
          (snd . dimens $ fle)
  ...
applyMoveUnbounded ::
Action -> Observation -> Word -> Observation
...
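The bodies of these two helpers are left out above. As a rough sketch of what they could look like, assuming a row-major grid where moving down adds the number of columns, here is one possibility. The names legalMovesSketch and applyMoveSketch are hypothetical, and the repository versions may differ.

```haskell
data Action = MoveLeft | MoveDown | MoveRight | MoveUp
  deriving (Show, Eq, Enum)

type Observation = Word

-- Moves that keep the player on the board, given (rows, columns).
-- Assumes the number of columns is positive.
legalMovesSketch :: Observation -> (Word, Word) -> [Action]
legalMovesSketch obs (numRows, numCols) =
  [ MoveLeft  | col > 0 ] ++
  [ MoveDown  | row + 1 < numRows ] ++
  [ MoveRight | col + 1 < numCols ] ++
  [ MoveUp    | row > 0 ]
  where
    (row, col) = obs `quotRem` numCols

-- New index for a move, with no bounds checking (hence "unbounded").
-- Only meaningful for moves that are legal from the given position.
applyMoveSketch :: Action -> Observation -> Word -> Observation
applyMoveSketch action obs numCols = case action of
  MoveLeft  -> obs - 1
  MoveRight -> obs + 1
  MoveDown  -> obs + numCols
  MoveUp    -> obs - numCols
```

From the top-left corner of a 4x4 board, for example, only MoveDown and MoveRight are legal, so a slip there picks between those two.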
To wrap things up we have to figure out the consequences of this move. If it lands us on the goal tile, we're done and we get a reward! If we hit a hole, the game is over but our reward is 0. Otherwise there's no reward and the game isn't over. We put
all our new state data into our environment and return the necessary values.
stepEnv :: (Monad m) => Action
-> StateT FrozenLakeEnvironment m (Observation, Double, Bool)
stepEnv act = do
...
let (done, reward) = case (grid fle) A.! newObservation of
Goal -> (True, 1.0)
Hole -> (True, 0.0)
_ -> (False, 0.0)
put $ fle { currentObservation = newObservation
, randomGenerator = finalGen
, previousAction = Just act }
return (newObservation, reward, done)
Playing the Game
One last step! We want to be able to play our game by creating a gameLoop. The final result of our loop will be the last observation and the game's reward. As an argument, we'll pass an expression that can generate an action. We'll give two options: one for reading a line from the user, and another for selecting randomly. Notice the use of toEnum, so we're entering numbers 0-3.
gameLoop :: (MonadIO m) =>
StateT FrozenLakeEnvironment m Action ->
StateT FrozenLakeEnvironment m (Observation, Double)
gameLoop chooseAction = do
...
chooseActionUser :: (MonadIO m) => m Action
chooseActionUser = (toEnum . read) <$> (liftIO getLine)
chooseActionRandom :: (MonadIO m) => m Action
chooseActionRandom = toEnum <$> liftIO (Rand.randomRIO (0, 3))
Within each stage of the loop, we render the environment, generate a new action, and step the game. Then if we're done, we return the results. Otherwise, recurse. The power of the state monad makes this function quite simple!
gameLoop :: (MonadIO m) =>
StateT FrozenLakeEnvironment m Action ->
StateT FrozenLakeEnvironment m (Observation, Double)
gameLoop chooseAction = do
renderEnv
newAction <- chooseAction
(newObs, reward, done) <- stepEnv newAction
if done
then do
liftIO $ print reward
liftIO $ putStrLn "Episode Finished"
renderEnv
return (newObs, reward)
else gameLoop chooseAction
And now to play our game, we start with a simple environment and execute our loop!
basicEnv :: IO FrozenLakeEnvironment
basicEnv = do
gen <- Rand.getStdGen
return $ FrozenLakeEnvironment
{ currentObservation = 0
, grid = A.listArray (0, 15) (charToTile <$> "SFFFFHFHFFFHHFFG")
, slipChance = 0.0
, randomGenerator = gen
, previousAction = Nothing
, dimens = (4, 4)
}
playGame :: IO ()
playGame = do
env <- basicEnv
void $ execStateT (gameLoop chooseActionUser) env
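Note that execStateT returns only the final environment, so the loop's own result (the last observation and reward) is dropped here. If we wanted that result instead, evalStateT would be the runner to use, and runStateT gives both. The toy sketch below shows the difference using the transformers package's Control.Monad.Trans.State module; countdown and demo are made-up names standing in for the game loop.

```haskell
import Control.Monad.Trans.State
  (StateT, evalStateT, execStateT, get, put, runStateT)

-- Toy loop: count the state down to zero and return how many steps it took.
countdown :: Monad m => StateT Int m Int
countdown = go 0
  where
    go steps = do
      n <- get
      if n <= 0
        then return steps
        else put (n - 1) >> go (steps + 1)

demo :: IO ((Int, Int), Int, Int)
demo = do
  both   <- runStateT countdown 3   -- (result, final state)
  result <- evalStateT countdown 3  -- result only
  final  <- execStateT countdown 3  -- final state only
  return (both, result, final)
```

Starting from 3, the loop takes three steps and ends with state 0, so the three runners yield (3, 0), 3, and 0 respectively.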
Conclusion
This example illustrates two main lessons. First, the state monad is very powerful for managing any type of game situation. Second, defining our API makes implementation straightforward. Next week, we'll explore another toy example with a different state space. This will lead us on the path to generalizing our data structure.
Remember, if you need any more details about these code samples, take a look at the full code on Github! You should also subscribe to Monday Morning Haskell! You'll get our monthly newsletter and access to our subscriber resources!
Open AI Primer: Frozen Lake!
Last year, we spent quite a bit of time on this blog creating a game using the Gloss library. This process culminated in trying to use machine learning to train an agent to play our Maze Game well. The results were not particularly successful. But I've always wanted to come back to the idea of reinforcement learning for game agents.
The Open AI Gym is an open source project for teaching the basics of reinforcement learning. It provides a framework for understanding how we can make agents that evolve and learn. It's written in Python, so this first article will be mostly in Python. But we can (and will) try to implement many of the ideas in Haskell. This week, we'll start exploring some of the core concepts. We'll examine what exactly an "environment" is and how we can generalize the concept. In time, we'll also see how Gloss can help us.
We'll ultimately use machine learning to train our agents. So you'll want some guidance on how to do that in Haskell. Read our Machine Learning Series and download our Tensor Flow guide to learn more!
Frozen Lake
To start out our discussion of AI and games, let's go over the basic rules of one of the simplest examples, Frozen Lake. In this game, our agent controls a character that is moving on a 2D "frozen lake", trying to reach a goal square. Aside from the start square ("S") and the goal zone ("G"), each square is either a frozen tile ("F") or a hole in the lake ("H"). We want to avoid the holes, moving only on the frozen tiles. Here's a sample layout:
SFFF
FHFH
FFFH
HFFG
So a safe path would be to move down twice, move right twice, down again, and then right again. What complicates the matter is that tiles can be "slippery". So each turn, there's a chance we won't complete our move, and will instead move to a random neighboring tile.
Playing the Game
Now let's see what it looks like for us to actually play the game using the normal Python code. This will get us familiar with the main ideas of an environment. We start by "making" the environment and setting up a loop where the user can enter their input move each turn:
import gym
env = gym.make('FrozenLake-v0')
env.reset()
while True:
move = input("Please enter a move:")
...
There are several functions we can call on the environment to see it in action. First, we'll render
it, even before making our move. This lets us see what is going on in our console. Then we have to step
the environment using our move. The step
function makes our move and provides us with 4 outputs. The primary ones we're concerned with are the "done" value and the "reward". These will tell us if the game is over, and if we won.
while True:
env.render()
move = input("Please enter a move:")
action = int(move)
observation, reward, done, info = env.step(action)
if done:
print(reward)
print("Episode finished")
env.render()
break
We use numbers in our moves, which our program converts into the input space for the game. (0 = Left, 1 = Down, 2 = Right, 3 = Up).
We can also play the game automatically, for several iterations. We'll select random moves by using action_space.sample(). We'll discuss what the action space is in the next part. We can also use reset on our environment at the start of each iteration to return the game to its initial state.
for i in range(20):
observation = env.reset()
for t in range(100):
env.render()
print(observation)
action = env.action_space.sample()
observation, reward, done, info = env.step(action)
if done:
print("Episode finished after {} timesteps".format(t + 1))
break
env.close()
These are the basics of the game. Let's go over some of the details of how an environment works, so we can start imagining how it will work in Haskell.
Observation and Action Spaces
The first thing to understand about environments is that each environment has an "observation" space and an "action" space. The observation space gives us a numerical representation of the state of the game. This doesn't include the actual layout of our board, just the mutable state. For our frozen lake example, this is only the player's current position. We could use two numbers for the player's row and column. But in fact we use a single number: the row number multiplied by the number of columns, plus the column number.
Here's an example where we print the observation after moving right twice, and then down. We have to call reset
before using an environment. Then calling this function gives us an observation we can print. Then, after each step, the first return value is the new observation.
import gym
env = gym.make('FrozenLake-v0')
o = env.reset()
print(o)
o, _, _, _ = env.step(2)
print(o)
o, _, _, _ = env.step(2)
print(o)
o, _, _, _ = env.step(1)
print(o)
# Console output
0
1
2
6
So, with a 4x4 grid, we start out at position 0. Then moving right increases our position index by 1, and moving down increases it by 4. (This assumes none of the moves slipped. On a slippery board, a run may print different numbers.)
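The numbers above are consistent with a row-major encoding, where the index is the row times the number of columns, plus the column. Looking ahead to Haskell, a minimal sketch of the conversion in both directions might look like this (toIndex and fromIndex are hypothetical helper names, and the column count is assumed to be positive):

```haskell
-- Flatten (row, col) into a single index for a grid with numCols columns.
toIndex :: Word -> (Word, Word) -> Word
toIndex numCols (row, col) = row * numCols + col

-- Recover (row, col) from a flat index.
fromIndex :: Word -> Word -> (Word, Word)
fromIndex numCols idx = idx `quotRem` numCols
```

On the 4x4 board, row 1 and column 2 gives index 6, which is the position we reached after moving right twice and down once.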
This particular environment uses a "discrete" observation space of size 16. So the state of the game is just a number from 0 to 15, indicating where our agent is. More complicated games will naturally have more complicated state spaces.
The "action space" is also discrete. We have four possible moves, so our different actions are the integers from 0 to 3.
import gym
env = gym.make('FrozenLake-v0')
print(env.observation_space)
print(env.action_space)
# Console Output
Discrete(16)
Discrete(4)
The observation space and the action space are important features of our game. They dictate the inputs and outputs of each game move. On each turn, we take a particular observation as input, and produce an action as output. If we can do this in a numerical way, then we'll ultimately be able to machine-learn the program.
Towards Haskell
Now we can start thinking about how to represent an environment in Haskell. Let's think about the key functions and attributes we used when playing the game.
- Observation space
- Action space
- Reset
- Step
- Render
How would we represent these in Haskell? To start, we can make a type for the different numeric spaces we can have. For now we'll provide a discrete space option and a continuous space option.
data NumericSpace =
Discrete Int |
Continuous Float
Now we can make an Environment
type with fields for these spaces. We'll give it parameters for the observation type and the action type.
data Environment obs act = Environment
{ observationSpace :: NumericSpace
, actionSpace :: NumericSpace
...
}
We don't know yet all the rest of the data our environment will hold. But we can start thinking about certain functions for it. Resetting will take our environment and return a new environment and an observation. Rendering will be an IO action.
resetEnv :: Environment obs act -> (obs, Environment obs act)
renderEnv :: Environment obs act -> IO ()
The step
function is the most important. In Python, this returns a 4-tuple. We don't care about the 4th "info" element there yet. But we do care to return our environment type itself, since we're in a functional language. So we'll return a different kind of 4-tuple.
stepEnv :: Environment obs act -> act
-> (obs, Float, Bool, Environment obs act)
It's also possible we'll use the state monad here instead, as that could be cleaner. Now this isn't the whole environment obviously! We'd need to store plenty of unique internal state. But what we see here is the start of a typeclass that we'll be able to generalize across different games. We'll see how this idea develops!
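To make the state monad idea a little more concrete, here is a sketch of how a pure step function of the shape above could be wrapped so that the environment is threaded implicitly. It uses state and runState from the transformers package's Control.Monad.Trans.State module. The names stepInState and toyStep are made up for this illustration, and the toy environment is just an Int position rather than a real Environment value.

```haskell
import Control.Monad.Trans.State (State, runState, state)

-- Wrap a pure step function, which returns the new environment as the
-- fourth tuple element, into a State action that threads it implicitly.
stepInState
  :: (env -> act -> (obs, Float, Bool, env))
  -> act
  -> State env (obs, Float, Bool)
stepInState pureStep action = state $ \env ->
  let (obs, reward, done, env') = pureStep env action
  in ((obs, reward, done), env')

-- Toy environment for illustration: the state is a position on a line,
-- the action is a signed offset, and the episode ends at position 3.
toyStep :: Int -> Int -> (Int, Float, Bool, Int)
toyStep pos delta =
  let pos' = pos + delta
      done = pos' == 3
  in (pos', if done then 1.0 else 0.0, done, pos')
```

With this wrapper, two steps can be chained as stepInState toyStep 1 >> stepInState toyStep 2 and run with runState from position 0, without passing the environment by hand between them.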
Conclusion
Hopefully you've got a basic idea now of what makes up an environment we can run. Next time, we'll push a bit further with our Haskell and implement Frozen Lake there!