James Bowen

General Functions with Typeclasses

Last week, we looked at the basics of Haskell’s data types. We saw that Haskell is not an object-oriented language, and we don’t have inheritance between data types. Inheritance would get very confusing with all the different constructors that a data type can have. Instead, Haskell provides much of the same functionality through typeclasses. This week we’ll take a quick look at this concept.

What is a Typeclass?

A typeclass encapsulates functionality that is common to different types. In practice, a typeclass describes a series of functions that you expect to exist for a given type. When these functions exist, you can create what is called an “instance” of a typeclass.

Typeclasses are a lot like interfaces in Java. You specify a group of functions, but only with the type signatures. Then for each relevant type, you'll need to specify an implementation for each function. As an example, suppose we had two different types referring to different kinds of people.

data Student = Student String Int

data Teacher = Teacher
  { teacherName :: String
  , teacherAge :: Int
  , teacherDepartment :: String
  , teacherSalary :: Int
  }

We could then make a typeclass called IsPerson. We'll give it a couple of functions that refer to the name and age of the person. Then we parameterize the class by the single type a, and use that type parameter in the type signatures of the functions:

class IsPerson a where
  personName :: a -> String
  personAge :: a -> Int

Creating Instances of Typeclasses

Now let's create an instance of the typeclass. All we have to do is implement each function under the instance keyword:

instance IsPerson Student where
  personName (Student name _) = name
  personAge (Student _ age) = age

instance IsPerson Teacher where
  personName = teacherName
  personAge = teacherAge

There are a lot of simple typeclasses in the base libraries that you’ll need to know for some basic tasks. For instance, to compare two items for equality, you’ll need the Eq typeclass:

class Eq a where
  (==) :: a -> a -> Bool
  (/=) :: a -> a -> Bool

We can define instances for these for all our types. But for simple, base library classes like these, GHC can define them for us! All we need is the deriving keyword:

data Student = Student String Int
  deriving (Eq)

data Teacher = Teacher
  { teacherName :: String
  , teacherAge :: Int
  , teacherDepartment :: String
  , teacherSalary :: Int
  }
  deriving (Eq)
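To see what the derived instance buys us, here is a small loadable sketch (for example in GHCi). The sample values are made up for illustration; the derived (==) treats two values as equal when they use the same constructor with equal fields.

```haskell
-- deriving (Eq) asks GHC to write (==) and (/=) for us. The derived (==)
-- checks that both values use the same constructor with equal fields.
data Student = Student String Int
  deriving (Eq)

-- Hypothetical sample comparisons.
sameStudent :: Bool
sameStudent = Student "Alice" 20 == Student "Alice" 20

differentAge :: Bool
differentAge = Student "Alice" 20 /= Student "Alice" 21
```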

Using Typeclass Constraints

But why are typeclasses important? Well, oftentimes we want to write code that is as general as possible. We want to write functions that assume as little about their inputs as they can. For instance, suppose we have this function that will print a teacher’s name:

printName :: Teacher -> IO ()
printName teacher = putStrLn $ personName teacher

We can use this function for more types than Teacher though! Any type that implements IsPerson will do. So we can make the function polymorphic, and add an IsPerson constraint on the type parameter a:

printName :: (IsPerson a) => a -> IO ()
printName person = putStrLn $ personName person
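Putting the pieces together, a sketch like the following should load as one file and show the same polymorphic function working on both types. The introduce helper is hypothetical, added only to have a pure function we can check; like printName, it assumes nothing beyond IsPerson.

```haskell
data Student = Student String Int

data Teacher = Teacher
  { teacherName :: String
  , teacherAge :: Int
  , teacherDepartment :: String
  , teacherSalary :: Int
  }

class IsPerson a where
  personName :: a -> String
  personAge :: a -> Int

instance IsPerson Student where
  personName (Student name _) = name
  personAge (Student _ age) = age

instance IsPerson Teacher where
  personName = teacherName
  personAge = teacherAge

-- Hypothetical helper: works for any type with an IsPerson instance.
introduce :: (IsPerson a) => a -> String
introduce person = personName person ++ " is " ++ show (personAge person)

printName :: (IsPerson a) => a -> IO ()
printName person = putStrLn $ personName person
```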

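You can also use typeclasses to constrain a type parameter of a new data type. Be aware that modern GHC only accepts this form with the deprecated DatatypeContexts extension. The context also doesn't make IsPerson available to code that uses the record; instead, every function that builds or pattern matches on it must repeat the (IsPerson a) constraint.

{-# LANGUAGE DatatypeContexts #-}

data (IsPerson a) => EmployeeRecord a = EmployeeRecord
  { employee :: a
  , employeeTenure :: Int
  }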
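A more common modern pattern is to leave the data type unconstrained and put the constraint on the functions that actually need it. Here is a minimal sketch of that approach; IsPerson is trimmed to a single method, and Contractor and describeRecord are hypothetical names for illustration.

```haskell
-- No context on the data declaration itself.
data EmployeeRecord a = EmployeeRecord
  { employee :: a
  , employeeTenure :: Int
  }

-- IsPerson trimmed to one method for this sketch.
class IsPerson a where
  personName :: a -> String

-- Hypothetical type used only for this example.
newtype Contractor = Contractor String

instance IsPerson Contractor where
  personName (Contractor name) = name

-- The constraint lives on the function that needs personName.
describeRecord :: (IsPerson a) => EmployeeRecord a -> String
describeRecord record =
  personName (employee record) ++ " (" ++ show (employeeTenure record) ++ " years)"
```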

Typeclasses can even provide a form of inheritance: you can constrain a typeclass by another typeclass! A couple of base classes show an example of this. The “Orderable” typeclass Ord depends on the type having an instance of Eq:

class (Eq a) => Ord a where
  compare :: a -> a -> Ordering
  (<=) :: a -> a -> Bool
  ...
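As a rough illustration of the superclass relationship, the sketch below writes Eq and Ord instances by hand for a hypothetical Age wrapper. GHC only accepts the Ord instance because an Eq instance exists, and in turn a function with just an Ord constraint may still use (==).

```haskell
-- Hypothetical wrapper type with hand-written instances.
newtype Age = Age Int

instance Eq Age where
  Age x == Age y = x == y

-- Accepted only because Eq Age exists (Eq is Ord's superclass).
-- Defining compare is enough; (<=), (<), max and friends get defaults.
instance Ord Age where
  compare (Age x) (Age y) = compare x y

-- Only Ord is in the constraint, yet (==) is available: Ord a implies Eq a.
isOldest :: (Ord a) => a -> [a] -> Bool
isOldest x xs = x == maximum (x : xs)
```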

Conclusion

Haskell programmers like code that is as general as possible. Object-oriented languages try to accomplish this with inheritance, but Haskell gets most of the same functionality with typeclasses instead. They describe common features between types, and provide a lot of flexibility.

To continue learning more about the Haskell basics, take a look at our Getting Started Checklist and get going!

Do you already understand the basics and want more of a challenge? Check out our Recursion Workbook!

Haskell Data Types in 5 Steps

People often speak of a dichotomy between “object-oriented” programming and “functional” programming. Haskell falls into the latter category, meaning we do more of our work with functions. We don't use hierarchies of objects to abstract work away. But Haskell is also heavily driven by its type system. So of course we still define our own data types in Haskell! Even better, Haskell has unique mechanisms you won't find in OO languages!

The Data Keyword and Constructors

In general, we define a new data type by using the data keyword, followed by the name of the type we’re defining. The type has to begin with a capital letter to distinguish it from normal expression names.

data Employee = ...

To start defining our type, we must provide a constructor. This is another capitalized word that allows you to create expressions of your new type. The constructor name is then followed by a list of 0 or more other types. These are like the “fields” that a data type carries in a language like Java or C++.

data Employee = Executive String Int Int

employee1 :: Employee
employee1 = Executive "Jane Doe" 38 300000

Sum Types

In a language like Java, you can have multiple constructors for a type. But the type will still encapsulate the same data no matter what constructor you use. In Haskell, you can have many constructors for your data type, separated by a vertical bar |. Each of your constructors then has its own list of data types! So different constructors of the same type can have different underlying data! We refer to a type with multiple constructors as a “sum” type.

data Employee
  = Executive String Int Int
  | VicePresident String String Int
  | Manager String String
  | Engineer String Int
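Because each constructor carries its own fields, functions on a sum type typically pattern match on every constructor. A sketch, assuming (for illustration only) that the first String in each constructor is the person's name; employeeName and isExecutive are hypothetical helpers.

```haskell
-- Field meanings are assumed for illustration: the first String is the name.
data Employee
  = Executive String Int Int
  | VicePresident String String Int
  | Manager String String
  | Engineer String Int

-- One equation per constructor, each with its own field layout.
employeeName :: Employee -> String
employeeName (Executive name _ _) = name
employeeName (VicePresident name _ _) = name
employeeName (Manager name _) = name
employeeName (Engineer name _) = name

-- Matching on just one constructor, with a catch-all for the rest.
isExecutive :: Employee -> Bool
isExecutive (Executive _ _ _) = True
isExecutive _ = False
```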

If your type has only one constructor, it is not uncommon to re-use the name of the type as the constructor name:

data Employee = Employee String Int

Record Syntax

You can also define a type using “record syntax”. This lets you give a name to each field in the constructor. Each field name then doubles as a simple function for accessing that field; otherwise, you'd need to resort to pattern matching. Record syntax is most common with types that have a single constructor. It is good practice to prefix your field names with the type name to avoid name conflicts.

data Employee = Employee
 { employeeName :: String
 , employeeAge :: Int
 }

printName :: Employee -> IO ()
printName employee = putStrLn $ employeeName employee
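For comparison, the sketch below constructs a record by field name and contrasts the generated accessor with the pattern match a plain constructor would need. PlainEmployee, plainName and the sample value jane are hypothetical.

```haskell
data Employee = Employee
  { employeeName :: String
  , employeeAge :: Int
  }

-- Record syntax also allows construction by field name, in any order.
jane :: Employee
jane = Employee { employeeAge = 38, employeeName = "Jane Doe" }

-- The field name is an accessor: employeeName :: Employee -> String.
printName :: Employee -> IO ()
printName employee = putStrLn $ employeeName employee

-- Without record syntax, the same lookup needs a pattern match.
data PlainEmployee = PlainEmployee String Int

plainName :: PlainEmployee -> String
plainName (PlainEmployee name _) = name
```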

Type Synonyms

As in C++, you can create type synonyms, providing a second name for a type. Sometimes, expressions can mean different things, even though they have the same representation. Type synonyms can help keep these straight. To make a synonym, use the type keyword, the new name you would like to use to refer to your type, and then the original type.

type InterestRate = Float
type BankBalance = Float

applyInterest :: BankBalance -> InterestRate -> BankBalance
applyInterest balance interestRate = balance + (balance * interestRate)

Note, though, that type synonyms have no effect on type checking! The compiler treats each synonym as exactly the original type, so it is still quite possible to misuse them. Either of the following type signatures would still compile for this function:

applyInterest :: Float -> Float -> Float
applyInterest :: InterestRate -> BankBalance -> Float
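The same problem shows up at call sites. In this sketch (sample values are made up), swapping the arguments type checks without complaint and silently produces a wrong answer.

```haskell
type InterestRate = Float
type BankBalance = Float

applyInterest :: BankBalance -> InterestRate -> BankBalance
applyInterest balance interestRate = balance + (balance * interestRate)

-- Hypothetical sample values.
savings :: BankBalance
savings = 1000

yearlyRate :: InterestRate
yearlyRate = 0.25

-- Intended call: 1000 + 1000 * 0.25
correct :: BankBalance
correct = applyInterest savings yearlyRate

-- Arguments swapped by mistake. This still compiles because both
-- synonyms are just Float: 0.25 + 0.25 * 1000
swapped :: BankBalance
swapped = applyInterest yearlyRate savings
```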

Newtypes

To avoid the confusion that can occur above, you can use the newtype keyword. A newtype is like a cross between data and type. Like type, you’re essentially renaming a type. But you do this by writing a declaration that has exactly one constructor with exactly one field. As with a data declaration, you can use record syntax within newtypes.

newtype BankBalance = BankBalance Float
newtype InterestRate = InterestRate { unInterestRate :: Float }

Once you’ve done this, you will have to use the constructors (or record functions) to wrap and unwrap your values:

applyInterest :: BankBalance -> InterestRate -> BankBalance
applyInterest (BankBalance bal) rate = BankBalance $
 bal + (unInterestRate rate * bal)
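Here is a loadable version of the newtype approach. getBalance and grown are hypothetical names for illustration; the commented-out call shows the kind of mistake the compiler now rejects.

```haskell
newtype BankBalance = BankBalance Float
newtype InterestRate = InterestRate { unInterestRate :: Float }

applyInterest :: BankBalance -> InterestRate -> BankBalance
applyInterest (BankBalance bal) rate = BankBalance $
  bal + (unInterestRate rate * bal)

-- Hypothetical helper to unwrap a balance when we need the raw number.
getBalance :: BankBalance -> Float
getBalance (BankBalance b) = b

-- Wrapping at the call site makes the intent of each argument visible.
grown :: BankBalance
grown = applyInterest (BankBalance 1000) (InterestRate 0.25)

-- Swapping the arguments is now a type error, e.g.
--   applyInterest (InterestRate 0.25) (BankBalance 1000)
-- is rejected because InterestRate and BankBalance are distinct types.
```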

Newtypes, unlike synonyms, are distinct types to the type checker. So the following invalid type signature will NOT compile!

applyInterest :: InterestRate -> BankBalance -> Float
applyInterest (BankBalance bal) (InterestRate rate) = ...

Conclusion

As we learned a couple of weeks ago, types are important in Haskell. So it’s not surprising that Haskell has some nifty constructs for building our own types. Constructors and sum types give us the flexibility to choose what kind of data we want to store. We can even change the data stored for different elements of the same type! Type synonyms and newtypes give us two different ways to rename our types. The first is easy and helps avoid confusion. The second requires more code rewriting, but provides more type safety.

If you’ve never written a line of Haskell before, never fear! Take a look at our Getting Started Checklist to get going!

Syntactic Implications of Expressions

Last week we explored expressions and types, the fundamental building blocks of Haskell. These concepts revealed some major structural differences between Haskell and procedural languages. This week we’ll consider the implications of these ideas. We'll see how they affect syntactic constructs in Haskell.

If Statements

Consider this Java function:

public int func(int input) {
  int z = 5;
  if (input % 2 == 0) {
    z = 4;
  }
  return z * input;
}

Here we see a very basic if-statement. Under a certain condition, we change the value of z. If the condition isn’t true we do nothing. But this is not how if statements work in Haskell! What lies inside the if statement is a command, and we compose our code with expressions and not commands!

In Haskell, an if-statement is an expression like anything else. This means it has to have a type! This constrains us in a couple of ways. If the condition is true, we supply an expression that will be the result, and the whole if-expression takes on the type of that expression. But what if the condition is not true? Can an expression be null or void? Most of the time, no!* The following is rather problematic:

myValue :: Int -> ??? -- What type would this have if the condition is false?
myValue x = if x `mod` 2 == 0 then 5

This means that if-statements in Haskell must have an else branch. Furthermore, the else branch must have an expression that has the same type as in the true branch!

myValue :: Int -> Int -- Both the false and true branches are Int
myValue x = if x `mod` 2 == 0 then 5 else 3
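To connect this back to the Java method at the top, here is one way (a sketch, not the only one) to write func in Haskell. Instead of assigning z and then mutating it, the if-expression simply produces the multiplier, and both branches have type Int.

```haskell
-- The Java func as a single expression: the if-expression picks the
-- multiplier (4 for even input, 5 otherwise), then we multiply.
func :: Int -> Int
func input = (if input `mod` 2 == 0 then 4 else 5) * input
```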

Notice the real difference here. We’re used to saying “if x, do y”. But in Haskell, we define an expression whose value depends on a condition. So from a conceptual standpoint, the “if” sits further inside the statement. This is one big hurdle to cross when first learning Haskell.

*Note: In monadic circumstances, a “null” value (represented as the unit type) can make sense. See when and unless.

Where Statements

We've established that everything in Haskell is an expression. But we often use commands to assign values to intermediate variables. How can we do this in Haskell? After all, if a computation is complicated, we don’t want to describe it all on one line.

The where clause fills this purpose. We can use where, followed by any number of indented bindings, to assign names to intermediate values!

mathFunc :: Int -> Int -> Int -> Int
mathFunc a b c = sum + product + difference
  where
    sum = a + b + c
    product = a * b * c
    difference = c - b - a

Statements in a where clause can be in any order. They can depend on each other, as long as there are no dependency loops!

mathFunc :: Int -> Int -> Int -> Int
mathFunc a b c = sum + product + difference
  where
    product = a * b * c * sum
    sum = a + b + c
    difference = sum - b - product

Let Statements

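There’s a second way of describing intermediate variables! We can also use a let expression, combined with in. As with where, the bindings in a single let block can refer to one another and may appear in any order. The difference is mostly placement: let introduces names before the expression that uses them, and let ... in is itself an expression. Here’s the above function, written using let: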

mathFunc :: Int -> Int -> Int -> Int
mathFunc a b c =
  let sum = a + b + c
      product = a * b * c
      difference = c - b - a
  in sum + product + difference
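A quick way to check how GHC treats binding order is to load both styles side by side. In this sketch the let version deliberately uses the same out-of-order bindings as the second where example; the names mathFuncWhere and mathFuncLet are just for this comparison. Note that the local sum and product shadow the Prelude functions of the same name, which is legal (GHC only warns with -Wall).

```haskell
-- The where version with out-of-order bindings, as in the post.
mathFuncWhere :: Int -> Int -> Int -> Int
mathFuncWhere a b c = sum + product + difference
  where
    product = a * b * c * sum
    sum = a + b + c
    difference = sum - b - product

-- The same bindings in one let block. Bindings within a single let are
-- mutually recursive, so product may use sum before sum is written.
mathFuncLet :: Int -> Int -> Int -> Int
mathFuncLet a b c =
  let product = a * b * c * sum
      sum = a + b + c
      difference = sum - b - product
  in sum + product + difference
```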

Why do we have two ways of expressing the same concept? Well, think about writing an essay. You'll often have terms you want to define for the reader. Sometimes, it makes more sense to define these before you use them. This is like using let. But sometimes, you can use the expression first, and show the details of how you calculated it later. This is what where is for! This works especially well when the expression name is descriptive.

As a side note, there are also situations with monads where you can’t use where and have to use let. But that’s a topic for another time!
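One common case, sketched here with the Maybe monad so it runs without IO: names bound with <- inside a do block are not in scope for a where clause attached to the function, so an intermediate value built from them needs a let. The function addThenDouble is a hypothetical example.

```haskell
-- x and y come from <-, so a where clause on addThenDouble could not see
-- them; the intermediate total has to be bound with let inside the do block.
addThenDouble :: Maybe Int -> Maybe Int -> Maybe Int
addThenDouble mx my = do
  x <- mx
  y <- my
  let total = x + y
  return (total * 2)
```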

Conclusion

If you’re new to Haskell, hopefully these short articles are giving you some quick insights into the language. For more details, take a look at our Getting Started Checklist! If you think you’ve mastered the basics and want to learn how to organize a real project, you should take our free Stack mini-course. It will teach you how to use the Stack tool to keep your Haskell code straight.

Back to Basics: Expressions and Types, Distilled

This week we begin our “Haskell Distilled” series. These articles will look at important Haskell concepts and break them down into the most fundamental parts. This week we start with the essence of what defines Haskell as a language.

Everything is an Expression

In Haskell, almost everything we write is an expression. This goes from the simplest primitive elements to the most complicated functions. Each of the following is an expression:

True
False
4
9.8
[6, 2]
\a -> a + 5
'a'
"Hello"

This is, in part, what defines Haskell as a functional language. We can compare this against a more procedural language like Java:

public int add5ToInput(int x) {
  int result = x + 5;
  return result;
}

The method we’ve defined here isn’t an expression in the same way a function is in Haskell. We can’t substitute the definition of this method in for an invocation of it, but in Haskell we can do exactly this with functions! The lines inside this method aren’t expressions either. Rather, each of them is a command to execute.
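A minimal sketch of that substitution property: a Haskell counterpart of add5ToInput, plus two hypothetical values showing that a call and its inlined definition are interchangeable.

```haskell
-- A Haskell counterpart to the Java method.
add5ToInput :: Int -> Int
add5ToInput x = x + 5

-- Replacing the call with the function's body gives the same value,
-- because the function is just an expression.
viaCall :: Int
viaCall = add5ToInput 10 * 2

viaSubstitution :: Int
viaSubstitution = (10 + 5) * 2
```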

Even Haskell code that appears to be procedural in nature is really an expression!

main = do
  putStrLn "Hello"
  putStrLn "World"

In this case, main is an expression of type IO (), not a list of procedures. Of course, the IO monad functions in a procedural manner sometimes. But still, the substitution property holds up.

Every Expression has a Type

Once we understand expressions, the next building block is types. All expressions have types. For instance, we can write out potential types for each of the expressions we had above:

True :: Bool
False :: Bool
4 :: Int
9.8 :: Float
[6, 2] :: [Int]
(\a -> a + 5) :: Int -> Int
'a' :: Char
"Hello" :: String
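The word "potential" matters here: numeric literals are polymorphic, so an annotation picks one of several valid types. A small sketch with hypothetical names:

```haskell
-- The same literal 4 can be given different numeric types.
fourAsInt :: Int
fourAsInt = 4

fourAsDouble :: Double
fourAsDouble = 4

-- A lambda is an ordinary expression with a function type.
addFive :: Int -> Int
addFive = \a -> a + 5

-- String is a synonym for [Char], so a string literal is a list of Chars.
hello :: String
hello = "Hello"
```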

Types tell us what operations we can perform with the data we have. They ensure we don’t try to perform operations with nonsensical data. For instance, we can only use the add operation (+) on expressions that are of the same type and that are numeric:

(+) :: (Num a) => a -> a -> a

By contrast, in JavaScript we can "add" arbitrary values (think [] + {}), which can produce nonsensical results.
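A short sketch of what the Num constraint allows: one definition (the hypothetical double) works at any numeric type, while an expression like True + 1 is rejected at compile time because Bool has no Num instance.

```haskell
-- Only a Num instance is required, so this works at Int, Double, etc.
double :: (Num a) => a -> a
double x = x + x

intResult :: Int
intResult = double 21

doubleResult :: Double
doubleResult = double 1.5

-- Something like  True + 1  would not compile: Bool has no Num instance.
```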

Functions are Still Just Expressions

As we saw in the Java example above, methods are a little cumbersome. There is a distinction between a method definition and the commands we use for normal code. But in Haskell, we define our more complicated expressions as functions. And these are still expressions just like our most primitive definitions! Functions have types as well, based on the parameters they take. So we could have functions that take one or more parameters:

add5 :: Int -> Int
add5 x = x + 5

sumAndProduct :: Int -> Int -> Int
sumAndProduct x y = (x + y) * (x + y)

In our Java example, the statement return result is fundamentally different from the method definition of add5ToInput. But above, sumAndProduct isn’t a different concept from the simpler expression 5.

We use functions by “applying” them to arguments like so:

sumAndProduct 7 8 :: Int

We can also partially apply functions. This means we supply only the first argument and leave the rest open-ended.

sumAndProduct 7 :: Int -> Int

This expression is a new function that takes only a single argument: it behaves like sumAndProduct with the first argument fixed as 7. We've peeled off one of the arguments from the function type.
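Partially applied functions are ordinary values, so they can be named or passed to other functions. A sketch using the sumAndProduct definition from above; withSeven and results are hypothetical names.

```haskell
sumAndProduct :: Int -> Int -> Int
sumAndProduct x y = (x + y) * (x + y)

-- Fixing the first argument leaves a function of one argument.
withSeven :: Int -> Int
withSeven = sumAndProduct 7

-- The partially applied function can be passed to map like any other.
results :: [Int]
results = map withSeven [1, 2, 3]
```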

Conclusion

In the coming weeks, we’ll continue breaking down simple concepts. We’ll examine how some of Haskell’s other syntactic constructs work within this system of expressions and types. If you’re new to Haskell, you should check out our Getting Started Checklist so you can start using these concepts!

1 Year of Haskell

This week marks the one year anniversary of Monday Morning Haskell! I’ve written an awful lot in the past year. It’s now obvious that the “blog” format doesn’t do justice to the amount of content on the site. There’s no organization, so a lot of the most useful content is stuck way down in the archives. So I’m taking this opportunity to announce some plans to reorganize the website!

Website Organization

I have two main focuses when it comes to Haskell. First, making the language easier for beginners to pick up. Second, showing the many industrial-strength tasks we can perform with Haskell. There will be two different sections of the website highlighting these components.

There will be a beginner’s corner that will focus on giving advice for starting the language from the ground up. It will feature content about the fundamental concepts and techniques of the language. It will also showcase some advice for conquering the psychological hurdles of Haskell. There will also be a section devoted to articles on production libraries, such as the recent series on building a Web API and using Tensor Flow.

Future Posts

This means I'll be focusing on organizing permanent content for a while. As a result, I won’t have as much time to spend on new blog posts. But I will still be publishing something new every Monday morning! These new posts will focus on distilling concepts into the most important parts. They'll look back at earlier lessons and pick out the highlights. Of course, I still have plans for some more in-depth technical tutorials. I’ll be looking at areas such as front-end web development, creating EDSLs, the C foreign function interface, and more.

Academy

I am also working hard on some new course material to go with our current free mini-course on learning Stack! I'm currently working on a full-length beginners course to help people learn Haskell from scratch. I will follow that up with another course aimed at using Haskell in production. This second course will cover generally important production skills and showcase Haskell libraries that apply those skills.

Conclusion

Altogether, the fun is just beginning at Monday Morning Haskell! I’m committed to continuing to build this website to be a resource for beginners and experts alike!

If you’re a beginner yourself, check out any of our newbie resources, like our Getting Started Checklist, our Recursion workbook, or the Stack mini-course mentioned above!

The Right Types of Assumptions

A little while back I was discussing my Haskell and AI article with some other engineers I know. These people are very smart and come from diverse backgrounds of software engineering. But they tend to prefer dynamically typed languages like Javascript and Ruby. So they don’t buy into Haskell, which makes them no different from most of the engineering community.

Still, some things about the discussion surprised me. I heard, of course, the oft-repeated mantra that types do not replace tests. But even more surprising was the claim that types do not, and are not even meant to, provide any kind of safety. One engineer even said he thinks of types only as directives to help a compiler optimize a program.

Now, it would be bold to claim that Haskell is better for all things and all circumstances. But I don’t think I’ve ever done that on this blog. It would also be unrealistic to claim that Haskell’s type system can replace all testing. And in fact I’ve dismissed that claim a couple times on this blog.

But surely Haskell’s type system gets us something. Surely, compiled types add some layer of extra safety and correctness to our programs. I’ve never thought this to be a particularly controversial point. And yet many smart people dismiss even this claim. So these conversations got me thinking: why do I think Haskell's type system is useful? Why do I believe what I do?

Convincing Myself

For me, the evidence is a matter of personal experience, and a lot of other Haskellers would agree. I’ve programmed in dynamic languages, such as Python and Javascript. I’ve also dealt a lot with Objective C and Java, which have static types but do not restrict effects as Haskell does.

It’s very clear to me where I get fewer bugs and write cleaner code that is more often correct, and that’s with Haskell. But of course I can’t base everything on my own experience, which could be subject to many biases. For instance, it could be that I haven’t spent enough time writing in dynamic languages at an industry level. It could also be that I happen to have become a more competent programmer in the last year. Perhaps I would have seen similar improvements no matter what language I focused on. Still, Haskell changed my programming in ways that other languages wouldn't have. I feel I could now go back to other languages and write better programs than I could before I learned Haskell.

Regardless, I would not feel as confident. Without Haskell's guarantees, there’s a great deal more mental overhead. You have to always ask yourself questions like, is every error condition checked? Are all my values non-null? Have I written enough tests to cover basic data format cases? Have I prevented effects like slow-running DB operations in performance critical places? Determining the answers to these might require a line-by-line audit of the code. Thorough code review is good. But when it becomes a pain, corners will get cut.

Perhaps if I had a couple more years of experience writing Javascript, I wouldn’t view these tasks as such a burden. But one of the best answers I ever read on Quora gave this piece of advice: as you become a better programmer, you don’t get better at keeping things in your head and knowing the whole system. You get better at sectioning off the system so that you don’t have to keep it all in your head. Haskell lends itself very well to this approach.

So it’s true that other languages can be written in a safe way. But Haskell forces you to write and think safely. In this regard, Haskell would actually be a fantastic first language.

Empirical Evidence?

Now, there have been some efforts to study the effects of programming language choice. This dissertation from UC Davis examined many open-source code bases to find what factors led to fewer bugs. The conclusion was that strongly typed, functional languages fare better. A much older paper from Yale found that Haskell was a clear winner when it came to code clarity (though Python and Javascript weren’t covered in that study).

Case closed right?

Well of course not. With the UC Davis paper, many possible issues prevent firm conclusions. After all, it only studied open source code and not industrial code. Confounding factors might well exist. Functional projects could just have fewer users. This could explain fewer bug reports. Also, code size dwarfed any effects of language choice. More code = more bugs.

But it’s hard to see how we could settle this question with an empirical study. No one has access to programmers with arbitrary levels of experience in a language. The Yale paper might have had the right idea. But it would need a larger scale to draw more solid conclusions.

What Other Evidence is out there?

So failing this kind of assessment, what is the best way to provide evidence in Haskell’s favor? It’s rather hard to look past testimony here. It would be great to see more publicity around stories like the one from Channable. Engineers there were able to convert a small piece of infrastructure to Haskell. The result was a system that had fewer bugs and that they could refactor with confidence.

Of course, testimony like this is still imperfect. There could be major survivorship bias. For every person who posts about Haskell solving their problems, there may be 5 who found it didn't help.

There’s also the fact of the world that many people don’t like change. You could provide all the evidence in the world in favor of a particular language. Some people will still choose to fall back on what they know. For people who have “tried and true” business models, this makes sense. Anything else seems risky.

So how do we Convince People?

I’ll repeat what I’ve noted a couple times in previous articles. Network effects are currently a huge drag on Haskell. The more Haskell developers who are out there, the easier it will be to convince companies to give it a try. Even if it is only on small projects, there will be more chances to see it succeed.

What is the best way to go about solving these network issues? For one, we need more marketing/Haskell evangelism. That said, there is still a sense that we need more educational material at all levels. Helping beginners get into the language is a great start. But there's definitely another hurdle to go from hobbyist to industry level.

Conclusion

Every once in a while it’s important to take a step back and consider our assumptions. I definitely found it worthwhile to reexamine some of my beliefs about Haskell’s type system. It's helped me to remember why I think the way I do. There’s a good amount of evidence out there for Haskell’s utility and safety as a language. But the burden is on us as a community to collect those stories and put them out there more.

If you haven’t tried Haskell before, hopefully I’ve convinced you to be one of those people! Check out our Getting Started Checklist for some good resources in starting your Haskell education.

If you’ve dabbled a bit and want to take the next step up, we also have a few other resources for you! Check out our Stack mini-course, our Recursion Workbook or our Haskell Tensor Flow guide!

Eff to the Rescue!

In the last couple weeks, we’ve seen quite a flurry of typeclasses. We used MonadDatabase and MonadCache to abstract the different effects within our API. This brought with it some benefits, but also some drawbacks. With these concepts abstracted, we could distill the API code into its simpler tasks. We didn't need to worry about connection configurations or lifting through different monads.

As we’ve seen though, there was a lot of boilerplate code involved. And there would be more if we wanted the freedom to have different parts of our app use different monad stacks. Free Monads are one solution to this problem. They allow you to compose your program so that the order in which you specify your effects does not matter. We’ll still have to write “interpretations” as we did before. But now they’ll be a lot more composable.

You can follow along with the code for this by checking out the effects-3 branch on Github. Also, I have to give a shoutout to Sandy Maguire for his talk on Eff and Free monads from BayHac. Most of what I know about free monads comes from that talk. You should check out his blog as well.

Typeclass Boilerplate

Let’s review the main drawback of our type class approach. Recall our original definition of the AppMonad, and some of the instances we had to write for it:

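-- Needed so GHC can derive Applicative and Monad through the newtype
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

newtype AppMonad a = AppMonad (ReaderT RedisInfo (SqlPersistT (LoggingT IO)) a)
  deriving (Functor, Applicative, Monad)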

instance MonadDatabase AppMonad where
  fetchUserDB = liftSqlPersistT . fetchUserDB
  createUserDB = liftSqlPersistT . createUserDB
  deleteUserDB = liftSqlPersistT . deleteUserDB
  fetchArticleDB = liftSqlPersistT . fetchArticleDB
  createArticleDB = liftSqlPersistT . createArticleDB
  deleteArticleDB = liftSqlPersistT . deleteArticleDB
  fetchArticlesByAuthor = liftSqlPersistT . fetchArticlesByAuthor
  fetchRecentArticles = liftSqlPersistT fetchRecentArticles

liftSqlPersistT :: SqlPersistT (LoggingT IO) a -> AppMonad a
liftSqlPersistT action = AppMonad $ ReaderT (const action)

instance (MonadIO m, MonadLogger m) => MonadDatabase (SqlPersistT m) where
  ...

But suppose another part of our application wants to use a different monad stack. Perhaps it uses different configuration information and tracks different state. But it still needs to be able to connect to the database. As a result, we’ll need to write more instances. Each of these will need a new definition for all the different effect functions. Almost all of these will be repetitive and involve some combination of lifts. This isn’t great. Further, suppose we want arbitrary reordering of the monad stack. Then the number of instances you’ll have to write scales quadratically. Once you get to six or seven layers, this is a real pain.

Main Ideas of Eff

We can get much better composability by using free monads. I’m not going to get into the conceptual details of free monads. Instead I’ll show how to implement them using the Eff monad from the Freer Effects library. Let's first think back to how we define constraints on monads in our handler functions.

fetchUsersHandler :: (MonadDatabase m, MonadCache m) => Int64 -> m User

We take some generic monad, and constrain it to implement our type classes. With Eff, we’ll specify constraints in a different way. We have only one monad, the Eff monad. This monad is parameterized by a type-level list of other monads that it carries on its stack.

type JustIO a = Eff '[IO] a

type ReaderWithIO a = Eff '[Reader RedisInfo, IO] a

With this in mind, we can specify constraints on what monads are part of our stack using Member. Here’s how we’ll re-write the type signature from above:

fetchUsersHandler :: (Member Database r, Member Cache r) => Int64 -> Eff r User

We’ll specify exactly what Database and Cache are in the next section. But in essence, we’re stating that we have these two kinds of effects that live somewhere on our monad stack r. It doesn’t matter what order they’re in! This gives us a lot of flexibility. But before we see why, let’s examine how we actually write these effects.

Coding Up Effects

The first thing we’ll do is represent our effects as data types, rather than type classes. Starting with our database functionality, we’ll make a type Database a. This type will have one constructor for each function from our MonadDatabase typeclass. We’ll capitalize the names since they’re constructors instead of function names. Then we’ll use GADT syntax, so that each constructor produces a value of type Database rather than a function in a particular monad. To start, here’s what our FetchUserDB constructor looks like:

{-# LANGUAGE GADTs #-}

data Database a where
  FetchUserDB :: Int64 -> Database (Maybe User)
  ...

Our previous definition looked like Int64 -> m (Maybe User). But we’re now constructing a Database action. Here’s the rest of the definition:

data Database a where
  FetchUserDB :: Int64 -> Database (Maybe User)
  CreateUserDB :: User -> Database Int64
  DeleteUserDB :: Int64 -> Database ()
  FetchArticleDB :: Int64 -> Database (Maybe Article)
  CreateArticleDB :: Article -> Database Int64
  DeleteArticleDB :: Int64 -> Database ()
  FetchArticlesByAuthor :: Int64 -> Database [KeyVal Article]
  FetchRecentArticles :: Database [(KeyVal User, KeyVal Article)]

Now we can also do the same thing with a Cache type instead of our MonadCache class:

data Cache a where
  CacheUser :: Int64 -> User -> Cache ()
  FetchCachedUser :: Int64 -> Cache (Maybe User)
  DeleteCachedUser :: Int64 -> Cache ()
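To see why the GADT syntax matters, here is a small sketch that interprets the Cache type above into a pure State monad over a Map instead of Redis. Matching on a constructor tells GHC which result type that branch must produce, so one function returns () for CacheUser and Maybe User for FetchCachedUser. User is simplified to a String here. cacheToMap and demo are illustrative names, not part of the article's code.

```haskell
{-# LANGUAGE GADTs #-}
module Main where

import Control.Monad.Trans.State.Strict (State, evalState, gets, modify)
import Data.Int (Int64)
import qualified Data.Map.Strict as Map

-- Simplified stand-in for the article's User type.
type User = String

-- Same constructors as the article's Cache effect.
data Cache a where
  CacheUser        :: Int64 -> User -> Cache ()
  FetchCachedUser  :: Int64 -> Cache (Maybe User)
  DeleteCachedUser :: Int64 -> Cache ()

-- Each equation is checked at the result type its constructor fixes:
-- () for the writes, Maybe User for the lookup.
cacheToMap :: Cache a -> State (Map.Map Int64 User) a
cacheToMap (CacheUser uid user)   = modify (Map.insert uid user)
cacheToMap (FetchCachedUser uid)  = gets (Map.lookup uid)
cacheToMap (DeleteCachedUser uid) = modify (Map.delete uid)

-- Cache a user, read it back, delete it, then read again.
demo :: (Maybe User, Maybe User)
demo = evalState script Map.empty
  where
    script = do
      cacheToMap (CacheUser 1 "alice")
      before <- cacheToMap (FetchCachedUser 1)
      cacheToMap (DeleteCachedUser 1)
      after <- cacheToMap (FetchCachedUser 1)
      return (before, after)

main :: IO ()
main = print demo
```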

Now, unfortunately, we do need some boilerplate with Eff. For each constructor we create, we’ll need a function to run that item within the Eff monad. For these, we’ll use the send function from the freer-effects library. Each function states that our effect type is a member of our effect list r. Otherwise, its type matches the type of the constructor, only within the Eff monad. Here are the three functions for our Cache type:

cacheUser :: (Member Cache r) => Int64 -> User -> Eff r ()
cacheUser uid user = send $ CacheUser uid user

fetchCachedUser :: (Member Cache r) => Int64 -> Eff r (Maybe User)
fetchCachedUser = send . FetchCachedUser

deleteCachedUser :: (Member Cache r) => Int64 -> Eff r ()
deleteCachedUser = send . DeleteCachedUser

But wait! You might be asking, aren’t we trying to avoid boilerplate? Well, it’s hard to avoid all boilerplate. But the real gain is that our boilerplate scales linearly. We only need this code once per effect type we create. Remember, the alternative is quadratic growth: with type classes, every monad we want to run in needs an instance of every class.

Interpreting our Effects

To write "interpretations" of our effects in the type class system, we wrote instances. Here, we can do it with functions that we'll prefix with run. These will assume we have an action where our effect is on "top" of the monad stack. The result will be a new action with that layer peeled off.

runDatabase :: Eff (Database ': r) a -> Eff r a
runDatabase = ...

Now, we have to consider what it takes to run our database effects. For our production application, we need to know that SqlPersistT lives in the monad stack. So we’ll add a Member (SqlPersistT (LoggingT IO)) r constraint on the rest of the stack, r.

runDatabase :: (Member (SqlPersistT (LoggingT IO)) r) => Eff (Database ': r) a -> Eff r a

So we are in effect constraining the ordering of our monad, but doing it in a logical way. It wouldn’t make sense for us to ever run our database effects without knowing about the database itself.

To write this function, we specify a transformation between this Member of the rest of our stack and our Database type. We can run this transformation with runNat:

runDatabase :: (Member (SqlPersistT (LoggingT IO)) r) => Eff (Database ': r) a -> Eff r a
runDatabase = runNat databaseToSql
  where
    databaseToSql :: Database a -> SqlPersistT (LoggingT IO) a
    ...

Now we need a conversion from a Database action to a SqlPersistT action. For this, we plug in all the different function definitions we’ve been using all along. For instance, here’s what our FetchUserDB and CreateUserDB cases look like:

databaseToSql (FetchUserDB uid) = get (toSqlKey uid)
databaseToSql (CreateUserDB user) = fromSqlKey <$> insert user

Our other constructors will follow this pattern as well.

Now, we’ll also want a way to interpret SqlPersistT effects within Eff. We’ll depend on only having IO as a deeper member within the stack here, though we also need the PGInfo parameter. Then we use runNat and convert between our SqlPersistT action and a normal IO action. We’ve done this before with runPGAction:

runSqlPersist :: (Member IO r) => PGInfo -> Eff ((SqlPersistT (LoggingT IO)) ': r) a -> Eff r a
runSqlPersist pgInfo = runNat $ runPGAction pgInfo

We go through this same process with Redis and our cache. To run a Redis action from our monad stack, we have to take the RedisInfo as a parameter and then also have IO on our stack:

runRedisAction :: (Member IO r) => RedisInfo -> Eff (Redis ': r) a -> Eff r a
runRedisAction redisInfo = runNat redisToIO
  where
    redisToIO :: Redis a -> IO a
    redisToIO action = do
      connection <- connect redisInfo
      runRedis connection action

Once we have this transformation, we can use the dependency on Redis to run Cache actions.

runCache :: (Member Redis r) => Eff (Cache ': r) a -> Eff r a
runCache = runNat cacheToRedis
  where
    cacheToRedis :: Cache a -> Redis a
    cacheToRedis (CacheUser uid user) = void $ setex (pack . show $ uid) 3600 (pack . show $ user)
    cacheToRedis (FetchCachedUser uid) = do
      result <- get (pack . show $ uid)
      case result of
        Right (Just userString) -> return $ Just (read . unpack $ userString)
        _ -> return Nothing
    cacheToRedis (DeleteCachedUser uid) = void $ del [pack . show $ uid]

And now we're done with our interpretations!

A Final Natural Transformation

Since we’re using Servant, we’ll still have to pick a final ordering. We need a natural transformation from Eff to Handler, and that requires one concrete Eff type with a specific list of effects. We’ll put our cache effects on the top of our stack, then database operations, and finally, plain old IO.

transformEffToHandler ::
  PGInfo ->
  RedisInfo ->
  (Eff '[Cache, Redis, Database, SqlPersistT (LoggingT IO), IO]) :~> Handler

So how do we define this transformation? As always, we’ll want to create an IO action that exposes an Either value so we can catch errors. First, we can use our different run functions to peel off all the layers on our stack until all we have is IO:

transformEffToHandler ::
  PGInfo ->
  RedisInfo ->
  (Eff '[Cache, Redis, Database, SqlPersistT (LoggingT IO), IO]) :~> Handler
transformEffToHandler pgInfo redisInfo = NT $ \action -> do
  -- ioAct :: Eff '[IO] a
  let ioAct = (runSqlPersist pgInfo . runDatabase . runRedisAction redisInfo . runCache) action
  ...

When we only have a single monad on our stack, we can use runM to get an action in that monad. So we need to apply that to our action, handle errors, and return the result as a Handler!

transformEffToHandler ::
  PGInfo ->
  RedisInfo -> 
  (Eff '[Cache, Redis, Database, SqlPersistT (LoggingT IO), IO]) :~> Handler
transformEffToHandler pgInfo redisInfo = NT $ \action -> do
  let ioAct = (runSqlPersist pgInfo . runDatabase . runRedisAction redisInfo . runCache) action
  result <- liftIO (runWithServantHandler (runM ioAct))
  Handler $ either throwError return result

And with that we’re done! Here’s the big win with Eff: it’s quite easy for us to write a different transformation on a different ordering of the stack. We just change the order in which we apply our run functions!

-- Put Database on top instead of Cache
transformEffToHandler :: 
  PGInfo -> 
  RedisInfo -> 
  (Eff '[Database, SqlPersistT (LoggingT IO), Cache, Redis, IO]) :~> Handler
transformEffToHandler pgInfo redisInfo = NT $ \action -> do
  let ioAct = (runRedisAction redisInfo . runCache . runSqlPersist pgInfo . runDatabase) action
  result <- liftIO (runWithServantHandler (runM ioAct))
  Handler $ either throwError return result

Can we avoid outside services with this approach? Sure! We can specify test interpretations of our effects that don’t use SqlPersistT or Redis. We’ll still have IO for the reasons mentioned last week, but it’s still an easy change. We'll define separate runTestDatabase and runTestCache functions that reuse the in-memory logic from last week's TestMonad. They’ll depend on a StateT over our in-memory maps.

runTestDatabase :: 
  (Member (StateT (UserMap, ArticleMap, UserMap) IO) r) => 
  Eff (Database ': r) a -> 
  Eff r a
runTestDatabase = runNat databaseToState
  where
    databaseToState :: Database a -> StateT (UserMap, ArticleMap, UserMap) IO a
    …

runTestCache ::
  (Member (StateT (UserMap, ArticleMap, UserMap) IO) r) =>
  Eff (Cache ': r) a ->
  Eff r a
runTestCache = runNat cacheToState
  where
    cacheToState :: Cache a -> StateT (UserMap, ArticleMap, UserMap) IO a
    ...

Then we fill in the definitions with the same functions we used when writing our TestMonad. After that, we define another natural transformation, in the same pattern:

transformTestEffToHandler ::
  MVar (UserMap, ArticleMap, UserMap) ->
  Eff '[Cache, Database, StateT (UserMap, ArticleMap, UserMap) IO] :~> Handler
transformTestEffToHandler sharedMap = NT $ \action -> do
  let stateAct = (runTestDatabase . runTestCache) action
  result <- liftIO (runWithServantHandler (runEff stateAct))
  Handler $ either throwError return result
  where
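    runEff :: Eff '[StateT (UserMap, ArticleMap, UserMap) IO] a -> IO a
    runEff effAction = do
      -- Read the shared maps, run the state action, then store the new maps
      currentState <- readMVar sharedMap
      (result, newState) <- runStateT (runM effAction) currentState
      _ <- swapMVar sharedMap newState
      return result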

Incorporating our Interpretations

The final step is to change a couple of type signatures within our API code. We’ll pass a new natural transformation to our Server function:

fullAPIServer :: 
  ((Eff '[Cache, Redis, Database, SqlPersistT (LoggingT IO), IO]) :~> Handler) ->
  Server FullAPI
fullAPIServer nt = ...

And then we’ll change all our handlers to use Eff with the proper members, instead of AppMonad:

fetchUsersHandler :: (Member Database r, Member Cache r) => Int64 -> Eff r User
createUserHandler :: (Member Database r) => User -> Eff r Int64
fetchArticleHandler :: (Member Database r) => Int64 -> Eff r Article
createArticleHandler :: (Member Database r) => Article -> Eff r Int64
fetchArticlesByAuthorHandler :: (Member Database r) => Int64 -> Eff r [KeyVal Article]
fetchRecentArticlesHandler :: (Member Database r) => Eff r [(KeyVal User, KeyVal Article)]

Conclusion

We’ve come a long way with our small application. It doesn’t do much. But it has served as a great launchpad for learning many interesting libraries and techniques. In particular, we’ve seen in these last few weeks how to organize effects within our application. With the Eff library, we can represent our effects with data types that we can re-order with ease.

If you’ve never tried Haskell before, give it a shot! Download our Getting Started Checklist and get going!

If you’ve done a little Haskell but aren’t confident in your skills yet, maybe this article went over your head. That’s OK! You can work on your skills more with our Recursion Workbook!

James Bowen

A Different Point of View: Interpreting our Monads Without Outside Services

Last week we updated our API to use some interesting monadic constructs. These allowed us to narrow down the places where effects could happen in our application. This week we’ll look at another advantage of this system: we can simplify our tests and remove their dependency on outside services.

You can follow along with this code on the effects-2 branches of the GitHub repository. In effects-2-start, we’ve updated our tests to use the AppMonad instead of normal IO functions. We can still do better, though (see the effects-2-end branch for the final product). We can create a second monad that implements our MonadDatabase and MonadCache classes. This gives us a different interpretation of our effects, one that doesn’t rely on running instances of Postgres and Redis.

Re-Imagining our Monad

Let’s imagine the simplest possible way to have a “database”. Instead of using a remote service, we could use in-memory maps. So let’s start with a couple type synonyms:

type UserMap = Map.Map Int64 User
type ArticleMap = Map.Map Int64 Article

There are three different maps in our application. The first map will be our normal Users table from the database. The second map will be the database’s Article table. The third map will refer to our Users cache. Now we’ll create a monad that links all these different elements together, and wraps them in StateT. We’ll then be able to update these maps between requests. We still need IO on our monad stack for reasons we’ll see later.

newtype TestMonad a = TestMonad (StateT (UserMap, ArticleMap, UserMap) IO a)
  deriving (Functor, Applicative, Monad)

instance MonadIO TestMonad where
  liftIO action = TestMonad $ liftIO action

Now we want to create instances of our database type classes for this monad. Let’s start an implementation of MonadDatabase by considering how we’ll fetch a user:

instance MonadDatabase TestMonad where
  fetchUserDB uid = ...

All we need to do is grab the first map out of our state tuple, and then use the normal Map lookup function! We can do the same with an article:

fetchUserDB uid = TestMonad $ do
  userDB <- (view _1) <$> get
  return $ Map.lookup uid userDB

fetchArticleDB aid = TestMonad $ do
  articleDB <- (view _2) <$> get
  return $ Map.lookup aid articleDB

Creating elements is a little more complicated, since we have to generate the keys. This isn’t that hard though! If the map is empty, we use 1 for the key. Otherwise, we find the max key and add 1 to it (note that the API for Map.findMax has changed since I wrote this):

createUserDB user = TestMonad $ do
  (userDB, articleDB, userCache) <- get
  let newUid = if Map.null userDB
        then 1
        else 1 + (fst . Map.findMax) userDB
  ...

Now we’ll create a modified map by inserting our new element. Then we’ll put the modified map back in along with the other maps:

createUserDB user = TestMonad $ do
  (userDB, articleDB, userCache) <- get
  let newUid = if Map.null userDB
        then 1
        else 1 + (fst . Map.findMax) userDB
  let userDB' = Map.insert newUid user userDB
  put (userDB', articleDB, userCache)
  return newUid

createArticleDB article = TestMonad $ do
  (userDB, articleDB, userCache) <- get
  let newAid = if Map.null articleDB
        then 1
        else 1 + (fst . Map.findMax) articleDB
  let articleDB' = Map.insert newAid article articleDB
  put (userDB, articleDB', userCache)
  return newAid

Deletion follows the same general pattern. The only difference is we delete from the map instead of inserting!

deleteUserDB uid = TestMonad $ do
  (userDB, articleDB, userCache) <- get
  let userDB' = Map.delete uid userDB
  put (userDB', articleDB, userCache)

deleteArticleDB aid = TestMonad $ do
  (userDB, articleDB, userCache) <- get
  let articleDB' = Map.delete aid articleDB
  put (userDB, articleDB', userCache)
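createUserDB and createArticleDB share the same key-generation logic, so one option is to factor it into a helper. Here's a sketch with hypothetical names nextKey and insertFresh. It uses Map.lookupMax (in containers since 0.5.9), which returns a Maybe, so the empty-map case needs no separate Map.null check.

```haskell
module Main where

import Data.Int (Int64)
import qualified Data.Map.Strict as Map

-- Next free key: 1 for an empty map, otherwise one past the largest key.
nextKey :: Map.Map Int64 a -> Int64
nextKey m = maybe 1 ((+ 1) . fst) (Map.lookupMax m)

-- Insert a value under a freshly generated key, returning the key too.
insertFresh :: a -> Map.Map Int64 a -> (Int64, Map.Map Int64 a)
insertFresh x m = (k, Map.insert k x m)
  where
    k = nextKey m

main :: IO ()
main = do
  let (k1, m1) = insertFresh "alice" Map.empty
      (k2, m2) = insertFresh "bob" m1
  print (k1, k2, Map.toList m2)
```

Like the original, this reuses a key if the current largest entry is deleted. That's fine for an in-memory test double, but it differs from how a real database sequence behaves.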

Now our final two functions will involve actually performing some application logic. To fetch articles by author, we get the list of articles in our database and filter it using the author ID:

fetchArticlesByAuthor uid = TestMonad $ do
  articleDB <- (view _2) <$> get
  return $ map KeyVal (filter articleByAuthor (Map.toList articleDB))
  where
    articleByAuthor (_, article) = articleAuthorId article == toSqlKey uid

For fetching the recent articles, we first sort all the articles in our map by timestamp. Then we take the ten most recent:

fetchRecentArticles = TestMonad $ do
  (userDB, articleDB, _) <- get
  let recentArticles = take 10 (sortBy orderByTimestamp (Map.toList articleDB)) 
  ...
  where
    orderByTimestamp (_, article1) (_, article2) =
      articlePublishedTime article2 `compare` articlePublishedTime article1

But now we have to match each of them with the right user. This involves performing a lookup based on the user ID. But then we’re done!

fetchRecentArticles = TestMonad $ do
  (userDB, articleDB, _) <- get
  let recentArticles = take 10 (sortBy orderByTimestamp (Map.toList articleDB)) 
  return $ map (matchWithAuthor userDB) recentArticles
  where
    orderByTimestamp (_, article1) (_, article2) =
      articlePublishedTime article2 `compare` articlePublishedTime article1
    matchWithAuthor userDB (aid, article) =
      case Map.lookup (fromSqlKey (articleAuthorId article)) userDB of
        Nothing -> error "Found article with no user" 
        Just u -> (KeyVal (fromSqlKey (articleAuthorId article), u), KeyVal (aid, article))
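The recent-articles logic can also be checked on its own as a pure function. This sketch uses simplified stand-in User and Article types with a numeric timestamp (assumptions, not the article's Persistent entities), plain tuples instead of KeyVal, and sortOn with Down to order newest first. Unlike the instance above, it skips an article whose author is missing instead of calling error.

```haskell
module Main where

import Data.Int (Int64)
import Data.List (sortOn)
import Data.Ord (Down (..))
import qualified Data.Map.Strict as Map

-- Simplified stand-ins for the article's entities.
newtype User = User { userName :: String } deriving (Show, Eq)

data Article = Article
  { articleTitle         :: String
  , articleAuthorId      :: Int64
  , articlePublishedTime :: Int  -- e.g. seconds since the epoch
  } deriving (Show, Eq)

-- The ten newest articles, each paired with its author.
-- Articles whose author is missing from the map are skipped.
recentArticles
  :: Map.Map Int64 User
  -> Map.Map Int64 Article
  -> [((Int64, User), (Int64, Article))]
recentArticles users articles =
  [ ((authorId, author), (aid, article))
  | (aid, article) <- take 10 newestFirst
  , let authorId = articleAuthorId article
  , Just author <- [Map.lookup authorId users]
  ]
  where
    newestFirst = sortOn (Down . articlePublishedTime . snd) (Map.toList articles)

sampleUsers :: Map.Map Int64 User
sampleUsers = Map.fromList [(1, User "alice"), (2, User "bob")]

sampleArticles :: Map.Map Int64 Article
sampleArticles = Map.fromList
  [ (1, Article "first" 1 100)
  , (2, Article "orphan" 99 300)  -- author 99 does not exist
  , (3, Article "second" 2 200)
  ]

main :: IO ()
main = mapM_ print (recentArticles sampleUsers sampleArticles)
```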

Our instance for MonadCache is very similar. We'll manipulate the third map instead of the first two:

instance MonadCache TestMonad where
  cacheUser uid user = TestMonad $ do
    (userDB, articleDB, userCache) <- get
    let userCache' = Map.insert uid user userCache
    put (userDB, articleDB, userCache')
  fetchCachedUser uid = TestMonad $ do
    userCache <- (view _3) <$> get
    return $ Map.lookup uid userCache
  deleteCachedUser uid = TestMonad $ do
    (userDB, articleDB, userCache) <- get
    let userCache' = Map.delete uid userCache
    put (userDB, articleDB, userCache')

Another Natural Transformation

Now we’re not quite done. We need the ability to run a version of our server that uses this interpretation of our effects. To do this, we need a natural transformation like we had before with AppMonad. Unfortunately, each request goes through the transformation separately, so a plain StateT would start every request from the initial maps and throw its changes away afterwards. To thread the maps between requests, we keep them behind a pointer. This is why we need IO on our stack. Here’s a function that reads the state from a pointer (an MVar), runs our action, and then swaps in the new maps.

runStateTWithPointer :: (Exception e, MonadIO m) => StateT s m a -> MVar s -> m (Either e a)
runStateTWithPointer action ref = do
  env <- liftIO $ readMVar ref
  (val, newEnv) <- runStateT action env
  void $ liftIO $ swapMVar ref newEnv
  return $ Right val
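Here is a small sketch of the same pointer idea with one change: it uses modifyMVar instead of a readMVar/swapMVar pair. modifyMVar keeps the MVar taken for the whole run, so two concurrent requests can't both read the old maps and then overwrite each other's changes, and it puts the old state back if the action throws. runWithRef, bump and demo are illustrative names, and the counter stands in for our maps.

```haskell
module Main where

import Control.Concurrent.MVar (MVar, modifyMVar, newMVar, readMVar)
import Control.Monad.Trans.State.Strict (StateT, get, modify, runStateT)

-- Run a StateT action against the state held in an MVar and store the
-- new state back, holding the MVar for the whole run.
runWithRef :: MVar s -> StateT s IO a -> IO a
runWithRef ref action = modifyMVar ref $ \s -> do
  (result, s') <- runStateT action s
  return (s', result)

-- A toy "request" that bumps a counter and returns the new value.
bump :: StateT Int IO Int
bump = modify (+ 1) >> get

-- Two separate runs share their state through the MVar.
demo :: IO (Int, Int, Int)
demo = do
  ref <- newMVar 0
  first <- runWithRef ref bump
  second <- runWithRef ref bump
  final <- readMVar ref
  return (first, second, final)

main :: IO ()
main = demo >>= print
```

Without the MVar, two independent runStateT bump 0 calls would each return 1, which is exactly the lost-update problem the pointer solves.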

Now for our transformation, we’ll take this pointer and run the state. Then we need to catch exceptions like we did in our transformation for AppMonad:

transformTestToHandler :: MVar (UserMap, ArticleMap, UserMap) -> TestMonad :~> Handler
transformTestToHandler sharedMap = NT $ \(TestMonad action) -> do
  result <- liftIO $ handleAny handler $
    runStateTWithPointer action sharedMap 
  Handler $ either throwError return result
  where
    handler :: SomeException -> IO (Either ServantErr a)
    handler e = return $ Left $ err500 { errBody = pack (show e) }

Now when we set up our tests, we’ll run our server using this transformation instead. Notice that we don’t have to do anything with Postgres or Redis here!

setupTests :: IO (ClientEnv, MVar (UserMap, ArticleMap, UserMap), ThreadId)
setupTests = do
  mgr <- newManager tlsManagerSettings
  baseUrl <- parseBaseUrl "http://127.0.0.1:8000"
  let clientEnv = ClientEnv mgr baseUrl
  let initialMap = (Map.empty, Map.empty, Map.empty)
  mapRef <- newMVar initialMap
  tid <- forkIO $
    run 8000 (serve usersAPI (testAPIServer (transformTestToHandler mapRef)))
  threadDelay 1000000
  return (clientEnv, mapRef, tid)

Now when our tests run, they’ll hit a server storing the information in memory instead of a Postgres server. This is super cool!

Integrating with our Tests

Unfortunately, it’s still a little awkward to write our tests. A lot of what they’re actually testing is the internal state of the “database” in question. So we need this function that takes the pointer to the map (the same pointer used by the server) and runs actions on it:

runTestMonad :: MVar (UserMap, ArticleMap, UserMap) -> TestMonad a -> IO a
runTestMonad mapVar (TestMonad action) = do
  currentState <- readMVar mapVar
  (result, newMap) <- runStateT action currentState
  _ <- swapMVar mapVar newMap
  return result

Now in our tests, we’ll wrap any calls to the database with this action. Here’s an example of our first before hook:

beforeHook1 :: ClientEnv -> MVar (UserMap, ArticleMap, UserMap) -> IO (Bool, Bool, Bool)
beforeHook1 clientEnv mapVar = do
  callResult <- runClientM (fetchUserClient 1) clientEnv
  let throwsError = isLeft callResult
  (inPG, inRedis) <- runTestMonad mapVar $ do
    inPG <- isJust <$> fetchUserDB 1
    inRedis <- isJust <$> fetchCachedUser 1
    return (inPG, inRedis)
  return (throwsError, inPG, inRedis)

One excellent consequence of using an in-memory map is that we don’t care if there’s data in our “database” at the end. Thus we can completely get rid of our after hooks, which were a bit of a pain!

main :: IO ()
main = do
  (clientEnv, dbMap, tid) <- setupTests
  hspec $ before (beforeHook1 clientEnv dbMap) spec1
  hspec $ before (beforeHook2 clientEnv dbMap) spec2
  hspec $ before (beforeHook3 clientEnv dbMap) spec3
  hspec $ before (beforeHook4 clientEnv dbMap) spec4
  hspec $ before (beforeHook5 clientEnv dbMap) spec5
  hspec $ before (beforeHook6 clientEnv dbMap) spec6
  killThread tid 
  return ()

And now our tests also run perfectly well without needing the docker container to be active! Hooray!

Conclusion

There’s a certain argument that we haven’t really accomplished much. Our app is very shallow, and most of the logic happens within the database calls themselves. Recall that many of our handler functions reduced to the database calls. Hence, the only thing we’re testing right now is our test interpretation!

But it’s easy to imagine that if our application were more complicated, this logic wouldn’t be at the core of our code. In most cases, database queries are the prelude to manipulating the data. And this TestMonad would remove the inconvenience of sourcing that data from outside.

Stay tuned for next week, where we’ll wrap up this consideration of effects by looking at free monads! We’ll consider the “freer-effects” library. It will let us cut down a bit on some of the boilerplate we get with this MTL style approach.

Never tried Haskell before? Do you have visions of conquering all foes with these sorts of abstractions? Check out our Getting Started Checklist and start your journey!

Have you dabbled a little but want to test your skills some more? Take a look at our Recursion Workbook!

James Bowen

Organizing our Effects Effectively

In the last 5 weeks or so, we’ve built a web application exposing a small API. The application is quite narrow, encompassing only a small amount of functionality. But it is still deep, covering several different libraries and techniques.

Over the next couple of weeks, we’ll look at some architectural considerations. We’ll observe some of the weaknesses of this system and how we can improve on them. This week will focus on an approach with type classes and monad transformers. In a couple of weeks, we’ll consider free monads and how we can use them.

You can follow along with this code on the effects-1 branch of the GitHub repo.

Weaknesses

In our current system, there are a lot of different functions like these:

fetchUserPG :: PGInfo -> Int64 -> IO (Maybe User)
createUserPG :: PGInfo -> User -> IO Int64
cacheUser :: RedisInfo -> Int64 -> User -> IO ()

Now, the parameters do inform us what each function should be accessing. But the functions are still regular IO functions. This means a novice programmer could come in and get the idea that it’s fine to use arbitrary effects. For instance, why not fetch our Postgres information from the Redis function? After all, fetchPostgresConnection is an IO function as well:

fetchPostgresConnection :: IO PGInfo
...

cacheUser :: RedisInfo -> Int64 -> User -> IO ()
cacheUser redisInfo uid user = do
  pgInfo <- fetchPostgresConnection
  -- Connect to Postgres instead of Redis :(
  ...

Our API also has some uncomfortable lifting in our handler functions. We have to call liftIO because all our database functions are IO functions.

fetchUsersHandler :: PGInfo -> RedisInfo -> Int64 -> Handler User
fetchUsersHandler pgInfo redisInfo uid = do
  -- liftIO #1
  maybeCachedUser <- liftIO $ fetchUserRedis redisInfo uid
  case maybeCachedUser of
    Just user -> return user
    Nothing -> do
      -- liftIO #2
      maybeUser <- liftIO $ fetchUserPG pgInfo uid
      case maybeUser of
        -- liftIO #3
        Just user -> liftIO (cacheUser redisInfo uid user) >> return user
        Nothing -> Handler $ (throwE $ err401 { errBody = "Could not find user with that ID" })

At the very least, our connection parameters are explicit here. If we hid them in a Reader, this would introduce even more lifts.

This article will focus on using type classes to restrict how we use effects. With any luck, we'll also clean up our code a bit and make it easier to test things. But we’ll focus more on testing next week.

Now, depending on the project size and scope, these weaknesses might not be issues. But it’s definitely a useful exercise to see alternative ways to organize our code.

Defining our Type Classes

Our first step for limiting our effects will be to create two type classes. We'll have one for our main database, and one for our cache. We'll try to make these functions agnostic to the underlying database representation. Hence, we’ll change our API to remove the notion of Entity. We’ll replace it with the idea of KeyVal, a wrapper around a tuple.

newtype KeyVal a = KeyVal (Int64, a)

With that, here are the 8 functions we have for accessing our database:

class (Monad m) => MonadDatabase m where
  fetchUserDB :: Int64 -> m (Maybe User) 
  createUserDB :: User -> m Int64 
  deleteUserDB :: Int64 -> m ()
  fetchArticleDB :: Int64 -> m (Maybe Article)
  createArticleDB :: Article -> m Int64
  deleteArticleDB :: Int64 -> m ()
  fetchArticlesByAuthor :: Int64 -> m [KeyVal Article]
  fetchRecentArticles :: m [(KeyVal User, KeyVal Article)]

And then we have three functions for how we interact with our cache:

class (Monad m) => MonadCache m where
  cacheUser :: Int64 -> User -> m ()
  fetchCachedUser :: Int64 -> m (Maybe User)
  deleteCachedUser :: Int64 -> m ()

We can now create instances of these type classes for any different monad we want to use. Let’s start by describing implementations for our existing libraries.

Writing Instances

We’ll start with SqlPersistT. We want to make an instance of MonadDatabase for it. We'll gather all the different functionality from the last few articles.

instance (MonadIO m, MonadLogger m) => MonadDatabase (SqlPersistT m) where
  fetchUserDB uid = get (toSqlKey uid)

  createUserDB user = fromSqlKey <$> insert user

  deleteUserDB uid = delete (toSqlKey uid :: Key User)

  fetchArticleDB aid = ((fmap entityVal) . listToMaybe) <$> (select . from $ \articles -> do
    where_ (articles ^. ArticleId ==. val (toSqlKey aid))
    return articles)

  createArticleDB article = fromSqlKey <$> insert article

  deleteArticleDB aid = delete (toSqlKey aid :: Key Article)

  fetchArticlesByAuthor uid = do
    entities <- select . from $ \articles -> do
      where_ (articles ^. ArticleAuthorId ==. val (toSqlKey uid))
      return articles
    return $ unEntity <$> entities

  fetchRecentArticles = do
    tuples <- select . from $ \(users `InnerJoin` articles) -> do
      on (users ^. UserId ==. articles ^. ArticleAuthorId)
      orderBy [desc (articles ^. ArticlePublishedTime)]
      limit 10
      return (users, articles)
    return $ (\(userEntity, articleEntity) -> (unEntity userEntity, unEntity articleEntity)) <$> tuples

Since we’re removing Entity from our API, we use this unEntity function. It will give us back the key and value as a KeyVal:

unEntity :: (ToBackendKey SqlBackend a) => Entity a -> KeyVal a
unEntity (Entity id_ val_) = KeyVal (fromSqlKey id_, val_)

Now we’ll do the same with our cache functions. We’ll make an instance of MonadCache for the Redis monad:

instance MonadCache Redis where
  cacheUser uid user = void $ setex (pack . show $ uid) 3600 (pack . show $ user)
  fetchCachedUser uid = do
    result <- get (pack . show $ uid)
    case result of
      Right (Just userString) -> return $ Just (read . unpack $ userString)
      _ -> return Nothing
  deleteCachedUser uid = void $ del [pack . show $ uid]
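One fragile spot in this instance is read: if a cached string no longer parses (say, after the User type gains a field), read throws in the middle of a request. Below is a sketch of a safer decoder using Text.Read.readMaybe. It models the lookup result as Either String (Maybe String) rather than the Reply and ByteString types hedis actually returns, and its User type is a simplified stand-in; decodeCached is an illustrative name.

```haskell
module Main where

import Text.Read (readMaybe)

-- Simplified stand-in for the article's User type.
data User = User { userName :: String, userAge :: Int }
  deriving (Show, Read, Eq)

-- Decode a value that was cached with `show`. A failed lookup, a missing
-- key and an unparsable payload all count as a cache miss, so the caller
-- falls back to the database instead of crashing in `read`.
decodeCached :: Read a => Either String (Maybe String) -> Maybe a
decodeCached (Right (Just payload)) = readMaybe payload
decodeCached _                      = Nothing

main :: IO ()
main = do
  print (decodeCached (Right (Just (show (User "alice" 30)))) :: Maybe User)
  print (decodeCached (Right (Just "not a user")) :: Maybe User)
```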

And that’s all there is here! Let’s see how we can combine these for easy use within our API.

Making our App Monad

We’d like to describe an “App Monad” that will allow us to access both these functionalities with ease. We’ll make a wrapper around a monad transformer incorporating a Reader for the Redis information and the SqlPersistT monad. We derive Monad for this type using GeneralizedNewtypeDeriving:

{-# LANGUAGE GeneralizedNewtypeDeriving #-}

newtype AppMonad a = AppMonad (ReaderT RedisInfo (SqlPersistT (LoggingT IO)) a)
  deriving (Functor, Applicative, Monad, MonadIO)

Now we’ll want to make instances of MonadDatabase and MonadCache. The instances are easy though; we'll use the instances for the underlying monads. First, let's define a transformation from an SqlPersistT action to our AppMonad. We need to build out the ReaderT RedisInfo for this. We'll use the ReaderT constructor and ignore the info with const.

liftSqlPersistT :: SqlPersistT (LoggingT IO) a -> AppMonad a
liftSqlPersistT action = AppMonad $ ReaderT (const action)
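ReaderT (const action) builds a Reader layer that ignores its environment. That is what lift from MonadTrans does for ReaderT, so liftSqlPersistT could equally be written as AppMonad . lift. Here is a small check with Maybe as the base monad; ignoreEnv, viaLift and greeting are illustrative names.

```haskell
module Main where

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT (..), ask)

-- Wrap an action that never looks at the environment.
ignoreEnv :: m a -> ReaderT r m a
ignoreEnv action = ReaderT (const action)

-- The MonadTrans instance for ReaderT builds the same wrapper.
viaLift :: Monad m => m a -> ReaderT r m a
viaLift = lift

-- For contrast, an action that does read its environment.
greeting :: Monad m => ReaderT String m String
greeting = fmap ("hello, " ++) ask

main :: IO ()
main = do
  print (runReaderT (ignoreEnv (Just (3 :: Int))) "ignored")
  print (runReaderT (viaLift (Just (3 :: Int))) "ignored")
  print (runReaderT greeting "redis-info" :: Maybe String)
```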

We can also define a transformation on Redis actions:

liftRedis :: Redis a -> AppMonad a
liftRedis action = do
  info <- AppMonad ask
  connection <- liftIO $ connect info
  liftIO $ runRedis connection action

We'll apply our underlying instances like so:

instance MonadDatabase AppMonad where
  fetchUserDB = liftSqlPersistT . fetchUserDB
  createUserDB = liftSqlPersistT . createUserDB
  deleteUserDB = liftSqlPersistT . deleteUserDB
  fetchArticleDB = liftSqlPersistT . fetchArticleDB
  createArticleDB = liftSqlPersistT . createArticleDB
  deleteArticleDB = liftSqlPersistT . deleteArticleDB
  fetchArticlesByAuthor = liftSqlPersistT . fetchArticlesByAuthor
  fetchRecentArticles = liftSqlPersistT fetchRecentArticles

instance MonadCache AppMonad where
  cacheUser uid user = liftRedis (cacheUser uid user)
  fetchCachedUser = liftRedis . fetchCachedUser 
  deleteCachedUser = liftRedis . deleteCachedUser

And that's it! We have our instances. Now we want to move on and figure out how we’ll actually incorporate this new monad into our API.

Writing a Natural Transformation

We would like to make it so that our handler functions can use AppMonad instead of the Handler monad. But Servant is sort of hard-coded to use Handler, so what do we do? The answer is we define a “natural transformation”.

I found this term to be a bit like "category". It seems innocuous but actually refers to something deeply mathematical. But we don't need to know too much to use it. The type operator (:~>) defines a natural transformation. All we need to make it is a function that takes an action in our monad and converts it into an action in the Handler monad. We'll need to pass our connection information to make this work.

transformAppToHandler :: PGInfo -> RedisInfo -> AppMonad :~> Handler
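A natural transformation is just a function from m a to n a that works for every a without inspecting it. Here is a sketch with a locally defined :~> newtype, shaped like the NT constructor used in this article but declared here rather than imported from Servant. The example transformations are illustrative.

```haskell
{-# LANGUAGE RankNTypes, TypeOperators #-}
module Main where

import Data.Maybe (listToMaybe, maybeToList)

-- A natural transformation from m to n: one function that works at
-- every result type a. Declared locally for this sketch.
newtype m :~> n = NT { runNT :: forall a. m a -> n a }

-- Keep a success, discard the error message.
eitherToMaybe :: Either String :~> Maybe
eitherToMaybe = NT (either (const Nothing) Just)

-- Keep only the first element of a list.
headMaybe :: [] :~> Maybe
headMaybe = NT listToMaybe

maybeToListNT :: Maybe :~> []
maybeToListNT = NT maybeToList

-- Natural transformations compose: run f, then g.
composeNT :: (n :~> p) -> (m :~> n) -> (m :~> p)
composeNT g f = NT (runNT g . runNT f)

main :: IO ()
main = do
  print (runNT eitherToMaybe (Right 'x' :: Either String Char))
  print (runNT headMaybe [1, 2, 3 :: Int])
  print (runNT (composeNT maybeToListNT eitherToMaybe) (Right 'x' :: Either String Char))
```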

We’ll start by defining a “handler” that will catch any errors we throw and recast them as Servant errors. In general, you want to list the specific types of exceptions you’ll catch. It's not a great idea to catch every exception like this. But for this example, we’ll keep it simple:

handler :: SomeException -> IO (Either ServantErr a)
handler e = return $ Left $ err500 { errBody = pack (show e)}

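Notice this returns an Either which is always a Left. Let's now define how we convert an action from our AppMonad into an Either as well. We’ll get the result and pass it on as a Right value. We'll define this in the where clause of our transformation below, so that pgInfo and redisInfo are in scope: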

runAppAction :: Exception e => AppMonad a -> IO (Either e a)
runAppAction (AppMonad action) = do
  result <- runPGAction pgInfo $ runReaderT action redisInfo
  return $ Right result

And putting it together, here’s our transformation. We catch errors, and then wrap the result up in Handler.

transformAppToHandler :: PGInfo -> RedisInfo -> AppMonad :~> Handler
transformAppToHandler pgInfo redisInfo = NT $ \action -> do
  result <- liftIO (handleAny handler (runAppAction action))
  Handler $ either throwError return result
  ...
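handleAny isn't defined in this article; it presumably comes from an exception-handling library. As a base-only sketch of the same catch-then-wrap-in-Either step, the block below uses Control.Exception.try; catchToEither and demo are illustrative names. It also shows one pitfall: an error buried in a lazy result is only thrown when the value is forced, which can happen after the handler has already returned.

```haskell
module Main where

import Control.Exception (SomeException, evaluate, throwIO, try)
import Data.Either (isLeft)

-- Run an IO action and turn any exception it throws into a Left
-- carrying the exception's text. Note: `try` at SomeException also
-- catches asynchronous exceptions (e.g. a thread being killed).
catchToEither :: IO a -> IO (Either String a)
catchToEither action = do
  result <- try action
  return $ case result of
    Left e  -> Left (show (e :: SomeException))
    Right a -> Right a

-- An `error` hidden in a lazy value escapes `try` unless forced first,
-- which is why the third case uses evaluate.
demo :: IO (Bool, Bool, Bool)
demo = do
  ok     <- catchToEither (return (42 :: Int))
  thrown <- catchToEither (throwIO (userError "db down") :: IO Int)
  forced <- catchToEither (evaluate (error "boom" :: Int))
  return (ok == Right 42, isLeft thrown, isLeft forced)

main :: IO ()
main = demo >>= print
```

Catch-all helpers in libraries such as safe-exceptions are documented as catching only synchronous exceptions, which is usually what you want in a request handler.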

Incorporating the App Monad

All we have to do now is incorporate our new monad into our handlers. First off, let’s change our API to remove Entities:

type FullAPI =
       "users" :> Capture "userid" Int64 :> Get '[JSON] User
  :<|> "users" :> ReqBody '[JSON] User :> Post '[JSON] Int64
  :<|> "articles" :> Capture "articleid" Int64 :> Get '[JSON] Article
  :<|> "articles" :> ReqBody '[JSON] Article :> Post '[JSON] Int64
  :<|> "articles" :> "author" :> Capture "authorid" Int64 :> Get '[JSON] [KeyVal Article]
  :<|> "articles" :> "recent" :> Get '[JSON] [(KeyVal User, KeyVal Article)]

We want to update the type of each function. The AppMonad incorporates all the configuration information. So we don’t need to pass connection information explicitly. Instead, we can use constraints on our monad type classes to expose those effects. Here’s what our type signatures look like:

fetchUsersHandler :: (MonadDatabase m, MonadCache m) => Int64 -> m User
createUserHandler :: (MonadDatabase m) => User -> m Int64
fetchArticleHandler :: (MonadDatabase m) => Int64 -> m Article
createArticleHandler :: (MonadDatabase m) => Article -> m Int64
fetchArticlesByAuthorHandler :: (MonadDatabase m) => Int64 -> m [KeyVal Article]
fetchRecentArticlesHandler :: (MonadDatabase m) => m [(KeyVal User, KeyVal Article)]

And now a lot of our functions are simple monadic calls. We don’t even need to use “lift”!

createUserHandler :: (MonadDatabase m) => User -> m Int64
createUserHandler = createUserDB

createArticleHandler :: (MonadDatabase m) => Article -> m Int64
createArticleHandler = createArticleDB

fetchArticlesByAuthorHandler :: (MonadDatabase m) => Int64 -> m [KeyVal Article]
fetchArticlesByAuthorHandler = fetchArticlesByAuthor

fetchRecentArticlesHandler :: (MonadDatabase m) => m [(KeyVal User, KeyVal Article)]
fetchRecentArticlesHandler = fetchRecentArticles

The “fetch” functions are a bit more complicated, since they have to handle missing results, and the user fetch checks the cache first. But again, all our functions are simple monadic calls without any lifting. Here’s how our fetch handlers look:

fetchUsersHandler :: (MonadDatabase m, MonadCache m) => Int64 -> m User
fetchUsersHandler uid = do
  maybeCachedUser <- fetchCachedUser uid
  case maybeCachedUser of
    Just user -> return user
    Nothing -> do
      maybeUser <- fetchUserDB uid
      case maybeUser of
        Just user -> cacheUser uid user >> return user
        Nothing -> error "Could not find user with that ID"

fetchArticleHandler :: (MonadDatabase m) => Int64 -> m Article
fetchArticleHandler aid = do
  maybeArticle <- fetchArticleDB aid
  case maybeArticle of
    Just article -> return article
    Nothing -> error "Could not find article with that ID"
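Because the handlers only mention MonadDatabase and MonadCache, we can run them against any monad that has instances of those classes. Here is a minimal, self-contained sketch of that idea. The classes are cut down to just the methods fetchUsersHandler uses (the real classes likely have more), User is a simplified stand-in, and InMemory is a hypothetical monad backed by an IORef rather than Postgres and Redis. The handler body is the same as above.

```haskell
import Data.IORef (IORef, modifyIORef, newIORef, readIORef)
import Data.Int (Int64)

-- Hypothetical stand-in for the User type from our schema: name and age.
data User = User String Int deriving (Eq, Show)

-- Simplified effect classes containing only what fetchUsersHandler needs.
class Monad m => MonadDatabase m where
  fetchUserDB :: Int64 -> m (Maybe User)

class Monad m => MonadCache m where
  fetchCachedUser :: Int64 -> m (Maybe User)
  cacheUser :: Int64 -> User -> m ()

-- Same logic as before: check the cache, fall back to the database,
-- and populate the cache on a database hit.
fetchUsersHandler :: (MonadDatabase m, MonadCache m) => Int64 -> m User
fetchUsersHandler uid = do
  maybeCachedUser <- fetchCachedUser uid
  case maybeCachedUser of
    Just user -> return user
    Nothing -> do
      maybeUser <- fetchUserDB uid
      case maybeUser of
        Just user -> cacheUser uid user >> return user
        Nothing -> error "Could not find user with that ID"

-- Hypothetical in-memory "world": database rows and cache entries.
data World = World { dbRows :: [(Int64, User)], cacheRows :: [(Int64, User)] }

newtype InMemory a = InMemory { runInMemory :: IORef World -> IO a }

instance Functor InMemory where
  fmap f (InMemory g) = InMemory (fmap f . g)

instance Applicative InMemory where
  pure x = InMemory (\_ -> pure x)
  InMemory f <*> InMemory x = InMemory (\ref -> f ref <*> x ref)

instance Monad InMemory where
  InMemory x >>= k = InMemory (\ref -> x ref >>= \a -> runInMemory (k a) ref)

instance MonadDatabase InMemory where
  fetchUserDB uid = InMemory (\ref -> lookup uid . dbRows <$> readIORef ref)

instance MonadCache InMemory where
  fetchCachedUser uid = InMemory (\ref -> lookup uid . cacheRows <$> readIORef ref)
  cacheUser uid user = InMemory (\ref ->
    modifyIORef ref (\w -> w { cacheRows = (uid, user) : cacheRows w }))
```

Running fetchUsersHandler 1 against a World that has user 1 only in dbRows returns that user and leaves it in cacheRows, with no database or Redis server involved. That kind of substitution is a large part of what makes this typeclass approach attractive.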

And now we’ll change our Server function. We’ll update it so that it takes our natural transformation as an argument. Then we’ll use the enter function combined with that transformation. This is how Servant knows what monad we want for our handlers:

fullAPIServer :: (AppMonad :~> Handler) -> Server FullAPI
fullAPIServer naturalTransformation =
  enter naturalTransformation $
    fetchUsersHandler :<|>
    createUserHandler :<|>
    fetchArticleHandler :<|>
    createArticleHandler :<|>
    fetchArticlesByAuthorHandler :<|>
    fetchRecentArticlesHandler

runServer :: IO ()
runServer = do
  pgInfo <- fetchPostgresConnection
  redisInfo <- fetchRedisConnection
  -- Pass the natural transformation as an argument!
  run 8000 (serve usersAPI (fullAPIServer (transformAppToHandler pgInfo redisInfo)))

And now we’re done!

Weaknesses with this Approach

Of course, this system is not without its weaknesses. In particular, there’s quite a bit of boilerplate. This is especially true if we don’t want to fix the ordering of our monad stack. For instance, what if another part of our application puts SqlPersistT on top of Redis? What if we want to mix in other monad transformers? We’ll need new instances of MonadDatabase and MonadCache for each of those stacks, which means writing many more simple, repetitive definitions. We’ll examine solutions to this weakness in a couple of weeks when we look at free monads.

We’ll also need to add new functions to our type classes every time we want to update their functionality. And then we’ll have to update EVERY instance of that typeclass, which can be quite a pain. The more instances we have, the more painful it will be to add new functionality.

Conclusion

So with a few useful tricks, we can come up with code that is a lot cleaner. We employed type classes to great effect to limit how effects appear in our application. By writing instances of these classes for different monads, we can change the behavior of our application. Next week, we’ll see how we can use this behavior to write simpler tests!

When managing an application with this many dependencies, you need the right tools. I use Stack for all my Haskell project organization. Check out our free Stack mini-course to learn more!

But if you’ve never tried Haskell at all, give it a try! Take a look at our Getting Started Checklist.

James Bowen

Join the Club: Type-safe Joins with Esqueleto!

In the last four articles or so, we’ve done a real whirlwind tour of Haskell libraries. We created a database schema using Persistent and used it to write basic SQL queries in a type-safe way. We saw how to expose this database via an API with Servant. We also went ahead and added some caching to that server with Redis. Finally, we wrote some basic tests around the behavior of this API. By using Docker, we made those tests reproducible.

In this article, we’re going to review this whole process by adding another type to our schema. We’ll write some new endpoints for an Article type, and link this type to our existing User type with a foreign key. Then we’ll learn one more library: Esqueleto. Esqueleto improves on Persistent by allowing us to write type-safe SQL joins.

As with the previous articles, there’s a specific branch on the GitHub repository for this series. Take a look at the esqueleto branch to see the complete code for this article.

Adding Article to our Schema

So our first step is to extend our schema with our Article type. We’re going to give each article a title, some body text, and a timestamp for its publishing time. One new feature we’ll see is that we’ll add a foreign key referencing the user who wrote the article. Here’s what it looks like within our schema:

PTH.share [PTH.mkPersist PTH.sqlSettings, PTH.mkMigrate "migrateAll"] [PTH.persistLowerCase|
 User sql=users
   ...

 Article sql=articles
   title Text
   body Text
   publishedTime UTCTime
   authorId UserId
   UniqueTitle title
   deriving Show Read Eq
|]

We can use UserId as a type in our schema. This will create a foreign key column when we create the table in our database. In practice, our Article type will look like this when we use it in Haskell:

data Article = Article
 { articleTitle :: Text
 , articleBody :: Text
 , articlePublishedTime :: UTCTime
 , articleAuthorId :: Key User
 }

This means it doesn’t reference the entire user; instead, it contains the SQL key of that user. Since we’ll be adding the article to our API, we also need ToJSON and FromJSON instances. These are fairly simple, so you can check them out here if you’re curious. If you’re curious about JSON instances in general, take a look at this article.

Adding Endpoints

Now we’re going to extend our API to expose certain information about these articles. First, we’ll write a couple of basic endpoints for creating an article and then fetching it by its ID:

type FullAPI = 
      "users" :> Capture "userid" Int64 :> Get '[JSON] User
 :<|> "users" :> ReqBody '[JSON] User :> Post '[JSON] Int64
 :<|> "articles" :> Capture "articleid" Int64 :> Get '[JSON] Article
 :<|> "articles" :> ReqBody '[JSON] Article :> Post '[JSON] Int64

Now, we’ll write a couple of special endpoints. The first will take a user ID and return all the articles that user has written. This endpoint will live at /articles/author/:authorid.

...
 :<|> "articles" :> "author" :> Capture "authorid" Int64 :> Get '[JSON] [Entity Article]

Our last endpoint will fetch the most recent articles, up to a limit of 10. This will take no parameters and live at the /articles/recent route. It will return tuples of users and their articles, both as entities.

…
 :<|> "articles" :> "recent" :> Get '[JSON] [(Entity User, Entity Article)]

Adding Queries (with Esqueleto!)

Before we can actually implement these endpoints, we’ll need to write the basic queries for them. For creating an article, we use the standard Persistent insert function:

createArticlePG :: PGInfo -> Article -> IO Int64
createArticlePG connString article = fromSqlKey <$> runAction connString (insert article)

We could do the same for the basic fetch endpoint. But we’ll write this query using Esqueleto so we can start learning its syntax. With Persistent, we used list parameters to specify different filters and SQL operations. Esqueleto instead uses a special monad to compose the different parts of a query. The general format of an Esqueleto select call looks like this:

fetchArticlePG :: PGInfo -> Int64 -> IO (Maybe Article)
fetchArticlePG connString aid = runAction connString selectAction
 where
   selectAction :: SqlPersistT (LoggingT IO) (Maybe Article)
   selectAction = select . from $ \articles -> do
     ...

We use select . from and then provide a function that takes a table variable. Our first queries will only refer to a single table, but we'll see a join later. To complete the function, we’ll provide the monadic action that will incorporate the different parts of our query.

The most basic filtering function we can call from within this monad is where_. This allows us to provide a condition on the query, much as we could with the filter list from Persistent.

   selectAction :: SqlPersistT (LoggingT IO) (Maybe Article)
   selectAction = select . from $ \articles -> do
     where_ (articles ^. ArticleId ==. val (toSqlKey aid))

First, we use the ^. operator with the ArticleId field to specify which column of our table we’re filtering on. Then we specify the value to compare against. We not only need to convert our Int64 into a key with toSqlKey, but we also need to lift that key into an SQL expression using the val function.

Now that we’ve added this condition, all we need to do is return the table variable. However, select returns our results in a list, and since we’re searching by ID, we only expect one result. We’ll use listToMaybe so we only return the head element if it exists. We’ll also use entityVal once again to unwrap the article from its entity.

   selectAction :: SqlPersistT (LoggingT IO) (Maybe Article)
   selectAction = ((fmap entityVal) . listToMaybe) <$> (select . from $ \articles -> do
     where_ (articles ^. ArticleId ==. val (toSqlKey aid))
     return articles)
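The post-processing around select is ordinary list code, so we can check it in isolation. Here is a small sketch with a hypothetical Entity stand-in (Persistent’s real Entity uses a Key type rather than a bare Int64):

```haskell
import Data.Int (Int64)
import Data.Maybe (listToMaybe)

-- Hypothetical stand-in for Persistent's Entity: a database key paired with a record.
data Entity a = Entity { entityKey :: Int64, entityVal :: a }

-- The same post-processing as ((fmap entityVal) . listToMaybe):
-- keep only the first row, if there is one, and drop its key.
firstRecord :: [Entity a] -> Maybe a
firstRecord = fmap entityVal . listToMaybe
```

An empty result list becomes Nothing, and any rows after the first are ignored, which is what we want when filtering on a unique ID.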

Now we should know enough to write the next query. It will fetch all the articles written by a particular user. We’ll still be querying the articles table. But now, instead of checking the article ID, we’ll make sure the ArticleAuthorId is equal to a certain value. Once again, we’ll convert our Int64 user key into an SqlKey with toSqlKey, and then lift it with val to compare it in “SQL-land”.

fetchArticlesByAuthorPG :: PGInfo -> Int64 -> IO [Entity Article]
fetchArticlesByAuthorPG connString uid = runAction connString fetchAction
 where
   fetchAction :: SqlPersistT (LoggingT IO) [Entity Article]
   fetchAction = select . from $ \articles -> do
     where_ (articles ^. ArticleAuthorId ==. val (toSqlKey uid))
     return articles

And that’s the full query! We want a list of entities this time, so we’ve taken out listToMaybe and entityVal.

Now let’s write the final query, where we’ll find the 10 most recent articles regardless of who wrote them. We’ll include the author along with each article, so we’re returning a list of tuples of entities. This query will involve our first join. Instead of using a single table, we’ll use the InnerJoin constructor to combine our users table with the articles table.

fetchRecentArticlesPG :: PGInfo -> IO [(Entity User, Entity Article)]
fetchRecentArticlesPG connString = runAction connString fetchAction
 where
   fetchAction :: SqlPersistT (LoggingT IO) [(Entity User, Entity Article)]
   fetchAction = select . from $ \(users `InnerJoin` articles) -> do

Since we’re joining two tables together, we need to specify what columns we’re joining on. We’ll use the on function for that:

   fetchAction :: SqlPersistT (LoggingT IO) [(Entity User, Entity Article)]
   fetchAction = select . from $ \(users `InnerJoin` articles) -> do
     on (users ^. UserId ==. articles ^. ArticleAuthorId)

Now we’ll order our articles based on the timestamp of the article using orderBy. The newest articles should come first, so we'll use a descending order. Then we limit the number of results with the limit function. Finally, we’ll return both the users and the articles, and we’re done!

   fetchAction :: SqlPersistT (LoggingT IO) [(Entity User, Entity Article)]
   fetchAction = select . from $ \(users `InnerJoin` articles) -> do
     on (users ^. UserId ==. articles ^. ArticleAuthorId)
     orderBy [desc (articles ^. ArticlePublishedTime)]
     limit 10
     return (users, articles)
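Esqueleto compiles this into a single SQL query, but it can help to spell out what the query means. Below is a plain-Haskell model of the same inner join, descending order, and limit over in-memory lists. It is only a sketch of the semantics, not how Esqueleto or Postgres evaluate the query, and UserRow and ArticleRow are hypothetical simplified rows with an Int timestamp in place of UTCTime.

```haskell
import Data.Int (Int64)
import Data.List (sortOn)
import Data.Ord (Down (..))

-- Hypothetical simplified rows.
data UserRow = UserRow { userId :: Int64, userName :: String } deriving (Eq, Show)

data ArticleRow = ArticleRow
  { articleId :: Int64
  , articleAuthorId :: Int64
  , articlePublishedTime :: Int
  } deriving (Eq, Show)

-- Models: users INNER JOIN articles ON users.id = articles.author_id
--         ORDER BY published_time DESC LIMIT n
recentArticles :: Int -> [UserRow] -> [ArticleRow] -> [(UserRow, ArticleRow)]
recentArticles n users articles =
  take n
    . sortOn (Down . articlePublishedTime . snd)
    $ [ (u, a) | u <- users, a <- articles, userId u == articleAuthorId a ]
```

Given 12 articles, recentArticles 10 keeps only the 10 newest user/article pairs, which is the behavior our recent-articles test checks against the real endpoint.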

Caching Different Types of Items

We won’t go into the details of caching our articles in Redis, but there is one potential issue we want to observe. Currently we’re using a user’s SQL key as their key in our Redis store. So for instance, the string “15” could be such a key. If we try to naively use the same idea for our articles, we’ll have a conflict! Trying to store an article with ID “15” will overwrite the entry containing the User!

The way around this is rather simple: we prefix each key with the kind of item it refers to. A user’s key becomes a string like users:15, while an article’s key becomes articles:15. As long as we deserialize each entry as the matching type, this will be fine.
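A minimal sketch of this key scheme, assuming plain Strings for simplicity (our Redis code would ultimately deal in ByteStrings). CacheKey, renderKey, and parseKey are hypothetical helpers, not part of the series’ code.

```haskell
import Data.Int (Int64)
import Text.Read (readMaybe)

-- Hypothetical tags for the kinds of items we cache.
data CacheKey = UserKey Int64 | ArticleKey Int64 deriving (Eq, Show)

-- Prefix each SQL key with its table name, e.g. "users:15" or "articles:15".
renderKey :: CacheKey -> String
renderKey (UserKey uid) = "users:" ++ show uid
renderKey (ArticleKey aid) = "articles:" ++ show aid

-- Parse a key back, rejecting unknown prefixes and non-numeric IDs.
parseKey :: String -> Maybe CacheKey
parseKey s = case break (== ':') s of
  ("users", ':' : rest) -> UserKey <$> readMaybe rest
  ("articles", ':' : rest) -> ArticleKey <$> readMaybe rest
  _ -> Nothing
```

Since user 15 and article 15 now render to different strings, storing one can no longer overwrite the other.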

Filling in the Server handlers

Now that we’ve written our database query functions, it is very simple to fill in our Server handlers. Most of them boil down to following the patterns we’ve already set with our other two endpoints:

fetchArticleHandler :: PGInfo -> Int64 -> Handler Article
fetchArticleHandler pgInfo aid = do
 maybeArticle <- liftIO $ fetchArticlePG pgInfo aid
 case maybeArticle of
   Just article -> return article
   Nothing -> Handler $ (throwE $ err404 { errBody = "Could not find article with that ID" })

createArticleHandler :: PGInfo -> Article -> Handler Int64
createArticleHandler pgInfo article = liftIO $ createArticlePG pgInfo article

fetchArticlesByAuthorHandler :: PGInfo -> Int64 -> Handler [Entity Article]
fetchArticlesByAuthorHandler pgInfo uid = liftIO $ fetchArticlesByAuthorPG pgInfo uid

fetchRecentArticlesHandler :: PGInfo -> Handler [(Entity User, Entity Article)]
fetchRecentArticlesHandler pgInfo = liftIO $ fetchRecentArticlesPG pgInfo

Then we’ll complete our Server FullAPI like so:

fullAPIServer :: PGInfo -> RedisInfo -> Server FullAPI
fullAPIServer pgInfo redisInfo =
 (fetchUsersHandler pgInfo redisInfo) :<|>
 (createUserHandler pgInfo) :<|>
 (fetchArticleHandler pgInfo) :<|>
 (createArticleHandler pgInfo) :<|>
 (fetchArticlesByAuthorHandler pgInfo) :<|>
 (fetchRecentArticlesHandler pgInfo)

One interesting thing we can do is compose our API types into different sections. For instance, we could separate our FullAPI into two parts: the UsersAPI type from before, and a new ArticlesAPI type. We can glue these together with the :<|> operator, just as we did with individual endpoints!

type FullAPI = UsersAPI :<|> ArticlesAPI

type UsersAPI =
      "users" :> Capture "userid" Int64 :> Get '[JSON] User
 :<|> "users" :> ReqBody '[JSON] User :> Post '[JSON] Int64

type ArticlesAPI =
 "articles" :> Capture "articleid" Int64 :> Get '[JSON] Article
 :<|> "articles" :> ReqBody '[JSON] Article :> Post '[JSON] Int64
 :<|> "articles" :> "author" :> Capture "authorid" Int64 :> Get '[JSON] [Entity Article]
 :<|> "articles" :> "recent" :> Get '[JSON] [(Entity User, Entity Article)]

If we do this, we’ll have to make similar adjustments everywhere we combine the endpoints. For example, we would need to restructure how the server handlers are joined and how the client functions are pattern matched.

Writing Tests

Since we already have some user tests, it would also be good to have a few tests on the Articles section of the API. We’ll add one simple test around creating an article and then fetching it. Then we’ll add one test each for the "articles-by-author" and "recent articles" endpoints.

One of the tricky parts of this section is that we need to make test Article objects, but they have to be functions of the user ID. This is because we can’t know a priori what SQL IDs we’ll get when we insert the users into the database. But we can fill in all the other fields, including the published time. Here’s one example; we’ll have a total of 18 different “test” articles.

testArticle1 :: Int64 -> Article
testArticle1 uid = Article
 { articleTitle = "First post"
 , articleBody = "A great description of our first blog post body."
 , articlePublishedTime = posixSecondsToUTCTime 1498914000
 , articleAuthorId = toSqlKey uid
 }

-- 17 other articles and some test users as well
…

Our before hooks will create all these different entities in the database. In general, we’ll go straight to the database without calling the API itself. Like with our users tests, we'll want to delete any database items we create. Let's write a generic after-hook that will take user IDs and article IDs and delete them from our database:

deleteArtifacts :: PGInfo -> RedisInfo -> [Int64] -> [Int64] -> IO ()
deleteArtifacts pgInfo redisInfo users articles = do
 forM_ articles (deleteArticlePG pgInfo)
 forM_ users $ \u -> do
   deleteUserCache redisInfo u
   deleteUserPG pgInfo u

It’s important we delete the articles first! If we delete the users first, we'll encounter foreign key exceptions!

Our basic create-and-fetch test looks a lot like the previous user tests. We test the success of the response and that the new article lives in Postgres as we expect.

beforeHook4 :: ClientEnv -> PGInfo -> IO (Bool, Bool, Int64, Int64)
beforeHook4 clientEnv pgInfo = do
 userKey <- createUserPG pgInfo testUser2
 articleKeyEither <- runClientM (createArticleClient (testArticle1 userKey)) clientEnv
 case articleKeyEither of
   Left _ -> error "DB call failed on spec 4!"
   Right articleKey -> do
     fetchResult <- runClientM (fetchArticleClient articleKey) clientEnv
     let callSucceeds = isRight fetchResult
     articleInPG <- isJust <$> fetchArticlePG pgInfo articleKey
     return (callSucceeds, articleInPG, userKey, articleKey)

spec4 :: SpecWith (Bool, Bool, Int64, Int64)
spec4 = describe "After creating and fetching an article" $ do
 it "The fetch call should return a result" $ \(succeeds, _, _, _) -> succeeds `shouldBe` True
 it "The article should be in Postgres" $ \(_, inPG, _, _) -> inPG `shouldBe` True

afterHook4 :: PGInfo -> RedisInfo -> (Bool, Bool, Int64, Int64) -> IO ()
afterHook4 pgInfo redisInfo (_, _, uid, aid) = deleteArtifacts pgInfo redisInfo [uid] [aid]

Our next test will create two different users and several different articles. We'll first insert the users and get their keys. Then we can use these keys to create the articles. We create five articles in this test. We assign three to the first user, and two to the second user:

beforeHook5 :: ClientEnv -> PGInfo -> IO ([Article], [Article], Int64, Int64, [Int64])
beforeHook5 clientEnv pgInfo = do
 uid1 <- createUserPG pgInfo testUser3
 uid2 <- createUserPG pgInfo testUser4
 articleIds <- mapM (createArticlePG pgInfo)
   [ testArticle2 uid1, testArticle3 uid1, testArticle4 uid1
   , testArticle5 uid2, testArticle6 uid2 ]
 ...

Now we want to test that when we call the articles-by-author endpoint, we only get the right articles. We’ll return each group of articles, the user IDs, and the list of article IDs:

beforeHook5 :: ClientEnv -> PGInfo -> IO ([Article], [Article], Int64, Int64, [Int64])
beforeHook5 clientEnv pgInfo = do
 uid1 <- createUserPG pgInfo testUser3
 uid2 <- createUserPG pgInfo testUser4
 articleIds <- mapM (createArticlePG pgInfo)
   [ testArticle2 uid1, testArticle3 uid1, testArticle4 uid1
   , testArticle5 uid2, testArticle6 uid2 ]
 firstArticles <- runClientM (fetchArticlesByAuthorClient uid1) clientEnv
 secondArticles <- runClientM (fetchArticlesByAuthorClient uid2) clientEnv
 case (firstArticles, secondArticles) of
   (Right as1, Right as2) -> return (entityVal <$> as1, entityVal <$> as2, uid1, uid2, articleIds)
   _ -> error "Spec 5 failed!"

Now we can write the assertion itself, testing that the articles returned are what we expect.

spec5 :: SpecWith ([Article], [Article], Int64, Int64, [Int64])
spec5 = describe "When fetching articles by author ID" $ do
 it "Fetching by the first author should return 3 articles" $ \(firstArticles, _, uid1, _, _) ->
   firstArticles `shouldBe` [testArticle2 uid1, testArticle3 uid1, testArticle4 uid1]
 it "Fetching by the second author should return 2 articles" $ \(_, secondArticles, _, uid2, _) ->
   secondArticles `shouldBe` [testArticle5 uid2, testArticle6 uid2]

We would then follow that up with a similar after hook.

The final test will follow a similar pattern. Only this time, we’ll be checking the combinations of users and articles. We’ll also make sure to include 12 different articles to test that the API limits results to 10.

beforeHook6 :: ClientEnv -> PGInfo -> IO ([(User, Article)], Int64, Int64, [Int64])
beforeHook6 clientEnv pgInfo = do
 uid1 <- createUserPG pgInfo testUser5
 uid2 <- createUserPG pgInfo testUser6
 articleIds <- mapM (createArticlePG pgInfo)
   [ testArticle7 uid1, testArticle8 uid1, testArticle9 uid1, testArticle10 uid2
   , testArticle11 uid2, testArticle12 uid1, testArticle13 uid2, testArticle14 uid2
   , testArticle15 uid2, testArticle16 uid1, testArticle17 uid1, testArticle18 uid2
   ]
 recentArticles <- runClientM fetchRecentArticlesClient clientEnv
 case recentArticles of
   Right as -> return (entityValTuple <$> as, uid1, uid2, articleIds)
   _ -> error "Spec 6 failed!"
 where
   entityValTuple (Entity _ u, Entity _ a) = (u, a)

Our spec will check that the list of 10 articles we get back matches our expectations. Then, as always, we remove the entities from our database.

Now we run these tests alongside our other tests, with small wrappers that attach the hooks using before and after:

main :: IO ()
main = do
 ...
 hspec $ before (beforeHook4 clientEnv pgInfo) $ after (afterHook4 pgInfo redisInfo) $ spec4
 hspec $ before (beforeHook5 clientEnv pgInfo) $ after (afterHook5 pgInfo redisInfo) $ spec5
 hspec $ before (beforeHook6 clientEnv pgInfo) $ after (afterHook6 pgInfo redisInfo) $ spec6

And now we’re done! The tests pass!

…
After creating and fetching an article
 The fetch call should return a result
 The article should be in Postgres

Finished in 0.1698 seconds
2 examples, 0 failures

When fetching articles by author ID
 Fetching by the first author should return 3 articles
 Fetching by the second author should return 2 articles

Finished in 0.4944 seconds
2 examples, 0 failures

When fetching recent articles
 Should fetch exactly the 10 most recent articles

Conclusion

This completes our overview of useful production libraries. Over these articles, we’ve constructed a small web API from scratch. We’ve seen some awesome abstractions that let us deal with only the most important pieces of the project. Both Persistent and Servant generated a lot of extra boilerplate for us. This article showed the power of the Esqueleto library in allowing us to do type-safe joins. We also saw an end-to-end process of adding a new type and endpoints to our API.

In the coming weeks, we’ll be dealing with some more issues that can arise when building these kinds of systems. In particular, we’ll see how we can use alternative monads on top of Servant. Doing this can present certain issues that we'll explore. We’ll culminate by exploring the different approaches to encapsulating effects.

Be sure to check out our Haskell Stack mini-course! It’ll show you how to use Stack, so you can incorporate all the libraries from this series!

If you’re new to Haskell and not ready for that yet, take a look at our Getting Started Checklist and get going!

James Bowen

Tangled Webs: Testing an Integrated System

In the last few articles, we’ve combined several useful Haskell libraries to make a small web app. We used Persistent to create a schema with automatic migrations for our database. Then we used Servant to expose this database as an API through a couple simple queries. Finally, we used Redis to act as a cache so that repeated requests for the same user happen faster.

For our next step, we’ll tackle a thorny issue: testing. How do we test a system that has so many moving parts? There are a couple of general approaches we could take. On one end of the spectrum, we could mock out most of our API calls and services. This helps give our tests deterministic behavior, which is desirable since we want to tie deployments to test results. But we also want to test our API faithfully. So on the other end, there’s the approach we’ll try in this article: we’ll set up functions that run our API and other services, and then use before and after hooks to exercise them.

Creating Client Functions for Our API

Calling our API from our tests means we’ll want a way to make API calls programmatically. We can do this with amazing ease with a Servant API by using the servant-client library. This library has one main function: client. This function takes a proxy for our API and generates programmatic client functions. Let's remember our basic endpoint types (after resolving connection information parameters):

fetchUsersHandler :: Int64 -> Handler User
createUserHandler :: User -> Handler Int64

We’d like to be able to call these endpoints with functions that take the same parameters. Those types might look something like this:

fetchUserClient :: Int64 -> m User
createUserClient :: User -> m Int64

Here m is some monad, and the servant-client library provides such a monad, ClientM. So let’s rewrite these type signatures, but leave them seemingly unimplemented:

fetchUserClient :: Int64 -> ClientM User
createUserClient :: User -> ClientM Int64

Now we’ll construct a pattern match that combines these function names with the :<|> operator. As always, we need to make sure we do this in the same order as the original API type. Then we’ll set this pattern to be the result of calling client on a proxy for our API:

fetchUserClient :: Int64 -> ClientM User
createUserClient :: User -> ClientM Int64
(fetchUserClient :<|> createUserClient) = client (Proxy :: Proxy UsersAPI)

And that’s it! The Servant library fills in the details for us and implements these functions! We’ll see how we can actually call these functions later in the article.
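The pattern binding itself is plain Haskell rather than Servant magic: client returns one combined value whose pieces are joined by :<|>, and a top-level pattern binding gives each piece a name. Here is a minimal sketch with a hypothetical stand-in for :<|> and a fake client assembled by hand, so it runs without servant-client.

```haskell
{-# LANGUAGE TypeOperators #-}

-- Hypothetical stand-in for Servant's (:<|>): a pair with an infix constructor.
data a :<|> b = a :<|> b
infixr 3 :<|>

-- A hand-built "client" holding two functions in API order. Servant's client
-- derives real HTTP-calling functions from the API type instead.
fakeClient :: (Int -> String) :<|> (String -> Int)
fakeClient = (\uid -> "user " ++ show uid) :<|> length

-- A top-level pattern binding splits the combined value into named functions.
fetchUserFake :: Int -> String
createUserFake :: String -> Int
(fetchUserFake :<|> createUserFake) = fakeClient
```

This also shows why the order matters: the pattern is matched positionally, so swapping the names would mismatch the types and fail to compile.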

Setting Up the Tests

We’d like to get to the business of deciding on our test cases and writing them. But first we need to make sure that our tests have a proper environment. This means 3 things. First we need to fetch the connection information for our data stores and API. This means the PGInfo, the RedisInfo, and the ClientEnv we’ll use to call the client functions we wrote. Second, we need to actually migrate our database so it has the proper tables. Third, we need to make sure our server is actually running. Let’s start with the connection information, as this is easy:

import Database (fetchPostgresConnection, fetchRedisConnection)

...

setupTests = do
  pgInfo <- fetchPostgresConnection
  redisInfo <- fetchRedisConnection
  ...

Now to create our client environment, we’ll need two main things: a manager for the network connections and the base URL for the API. Since we’re running the API locally, we’ll use a localhost URL. A manager created with tlsManagerSettings from the http-client-tls package will work fine for us:

import Network.HTTP.Client (newManager)
import Network.HTTP.Client.TLS (tlsManagerSettings)
import Servant.Client (ClientEnv(..), parseBaseUrl)

setupTests = do
  pgInfo <- fetchPostgresConnection
  redisInfo <- fetchRedisConnection
  mgr <- newManager tlsManagerSettings
  baseUrl <- parseBaseUrl "http://127.0.0.1:8000"
  let clientEnv = ClientEnv mgr baseUrl

Now we can run our migration, which will ensure that our users table exists:

import Schema (migrateAll)

setupTests = do
  pgInfo <- fetchPostgresConnection
  ...
  runStdoutLoggingT $ withPostgresqlConn pgInfo $ \dbConn ->
    runReaderT (runMigrationSilent migrateAll) dbConn

Last of all, we’ll start our server with runServer from our API module. We’ll fork this off to a separate thread, as otherwise it will block the test thread! We’ll wait for a second afterward to make sure it actually loads before the tests run (there are less hacky ways to do this of course). But then we’ll return all the important information we need, and we're done with test setup:

import Control.Concurrent (ThreadId, forkIO, threadDelay)

setupTests :: IO (PGInfo, RedisInfo, ClientEnv, ThreadId)
setupTests = do
  pgInfo <- fetchPostgresConnection
  redisInfo <- fetchRedisConnection
  mgr <- newManager tlsManagerSettings
  baseUrl <- parseBaseUrl "http://127.0.0.1:8000"
  let clientEnv = ClientEnv mgr baseUrl
  runStdoutLoggingT $ withPostgresqlConn pgInfo $ \dbConn ->
    runReaderT (runMigrationSilent migrateAll) dbConn
  -- Run the server on its own thread so it doesn't block the tests.
  serverThreadId <- forkIO runServer
  -- Crude: give the server one second (1,000,000 microseconds) to start.
  threadDelay 1000000
  return (pgInfo, redisInfo, clientEnv, serverThreadId)
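One less hacky alternative to the fixed one-second delay is to have the server signal when it is ready. Here is a minimal base-only sketch using an MVar; forkAndWait and fakeServer are hypothetical, and fakeServer stands in for runServer. A production version would also need to handle a server that fails before signalling, since takeMVar would otherwise wait forever.

```haskell
import Control.Concurrent (ThreadId, forkIO, threadDelay)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)

-- Fork a server and block until it reports that it is ready, instead of
-- sleeping for a fixed amount of time. The server receives an action to
-- call once its startup work is complete.
forkAndWait :: (IO () -> IO ()) -> IO ThreadId
forkAndWait server = do
  ready <- newEmptyMVar
  tid <- forkIO (server (putMVar ready ()))
  takeMVar ready -- blocks until the server signals readiness
  return tid

-- A hypothetical server: some startup work, then the readiness signal.
-- A real server would go on to handle requests after signalling.
fakeServer :: IO () -> IO ()
fakeServer signalReady = do
  threadDelay 10000 -- stand-in for binding the port, loading config, etc.
  signalReady
```

With warp, one option would be to build Settings with setBeforeMainLoop (putMVar ready ()) and start the app with runSettings. That would mean changing runServer, so treat it as a suggestion rather than the series’ code.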

Organizing our 3 Test Cases

Now that we’re all set up, we can decide on our test cases. We’ll look at 3. First, if we have an empty database and we fetch a user by some arbitrary ID, we’ll expect an error. Further, we should expect that the user does not exist in the database or in the cache, even after calling fetch.

In our second test case, we’ll look at the effects of calling the create endpoint. We’ll save the key we get from this endpoint. Then we’ll verify that this user exists in the database, but NOT in the cache. Finally, our third case will insert the user with the create endpoint and then fetch the user. We’ll expect at the end of this that in fact the user exists in both the database AND the cache.

We organize each of our tests into up to three parts: the “before hook”, the test assertions, and the “after hook”. A “before hook” is some IO code we’ll run that will return particular results to our test assertion. We want to make sure it’s done running BEFORE any test assertions. This way, there’s no interleaving of effects between our test output and the API calls. Each before hook will first make the API calls we want. Then they'll investigate our different databases and determine if certain users exist.

We also want our tests to be database-neutral. That is, the database and cache should be in the same state after the test as they were before. So we’ll also have “after hooks” that run after our tests have finished (if we’ve actually created anything). The after hooks will delete any new entries. This means our before hooks also have to pass the keys for any database entities they create. This way the after hooks know what to delete.
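In our suite, hspec’s before and after combinators handle this sequencing. To make the lifecycle concrete, here is a base-only sketch of the same before, assert, and after flow built on Control.Exception.bracket instead of hspec; it illustrates the idea and is not how hspec is implemented. FakeDB and both hooks are hypothetical. The key point is that the before hook’s result (here, a new key) flows to both the assertions and the after hook.

```haskell
import Control.Exception (bracket)
import Data.IORef (IORef, modifyIORef, newIORef, readIORef)

-- Hypothetical "database": a mutable list of user IDs.
type FakeDB = IORef [Int]

-- Before hook: create a user and hand its key to the assertions.
beforeHook :: FakeDB -> IO Int
beforeHook db = do
  modifyIORef db (42 :)
  return 42

-- After hook: delete whatever the before hook created.
afterHook :: FakeDB -> Int -> IO ()
afterHook db key = modifyIORef db (filter (/= key))

-- The assertions receive the before hook's result; bracket runs the
-- after hook even if an assertion throws.
runTest :: FakeDB -> (Int -> IO Bool) -> IO Bool
runTest db assertions = bracket (beforeHook db) (afterHook db) assertions
```

After runTest returns, the fake database is back in its original state, which is exactly the database-neutral property we want from our real hooks.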

Last of all, we need the test code that makes assertions about the results. These will be pretty straightforward, as we’ll see below.

Test #1

For our first test, we’ll start by making a client call to our API. We use runClientM combined with our clientEnv and the fetchUserClient function. Next, we’ll determine that the call in fact returns an error as it should. Then we’ll add two more lines checking if there’s an entry with the arbitrary ID in our database and our cache. Finally, we return all three boolean values:

beforeHook1 :: ClientEnv -> PGInfo -> RedisInfo -> IO (Bool, Bool, Bool)
beforeHook1 clientEnv pgInfo redisInfo = do
  callResult <- runClientM (fetchUserClient 1) clientEnv
  let throwsError = isLeft (callResult)
  inPG <- isJust <$> fetchUserPG pgInfo 1
  inRedis <- isJust <$> fetchUserRedis redisInfo 1
  return (throwsError, inPG, inRedis)

Now we’ll write our assertion. Since we’re using a before hook returning three booleans, the type of our Spec will be SpecWith (Bool, Bool, Bool). Each it assertion will take this boolean tuple as a parameter, though we’ll only use one for each line.

spec1 :: SpecWith (Bool, Bool, Bool)
spec1 = describe "After fetching on an empty database" $ do
  it "The fetch call should throw an error" $ \(throwsError, _, _) -> throwsError `shouldBe` True
  it "There should be no user in Postgres" $ \(_, inPG, _) -> inPG `shouldBe` False
  it "There should be no user in Redis" $ \(_, _, inRedis) -> inRedis `shouldBe` False

And that’s all we need for the first test! We don’t need an after hook since it doesn’t add anything to our database.
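It helps to know how often the hook runs. Hspec's documentation describes before as running its action before every spec item (beforeAll runs it once), and after as running after every item. The toy model below, which is a hypothetical sketch and not hspec's implementation, feeds a hook's result into each item and counts the hook's runs.

```haskell
import Data.IORef

-- A toy model (hypothetical, not hspec's implementation) of how
-- "before hook spec" feeds the hook's result into every "it" item.
type Item a = (String, a -> Bool)

-- Like hspec's before, this runs the hook once per item, not once per spec.
runBefore :: IO a -> [Item a] -> IO [(String, Bool)]
runBefore hook items = mapM runItem items
  where
    runItem (description, check) = do
      result <- hook
      return (description, check result)

-- A hook wrapper that counts its runs, to make that behavior visible.
countingHook :: IORef Int -> a -> IO a
countingHook counter value = modifyIORef counter (+ 1) >> return value

-- The three items from spec1, written as plain predicates on the tuple.
spec1Items :: [Item (Bool, Bool, Bool)]
spec1Items =
  [ ("The fetch call should throw an error", \(throwsError, _, _) -> throwsError)
  , ("There should be no user in Postgres", \(_, inPG, _) -> not inPG)
  , ("There should be no user in Redis", \(_, _, inRedis) -> not inRedis)
  ]
```

Running spec1Items through runBefore calls the hook three times. For hooks that create data, this per-item behavior is why the after hook matters: each item's before hook creates a user, and the matching after hook removes it.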

Tests 2 and 3

Now that we’re a little more familiar with how this code works, let’s take a quick look at the next before hook. This time we’ll first try creating our user. If this fails for whatever reason, we’ll throw an error and stop the tests. Then we can use the key to check whether the user exists in our database and in Redis. We return the boolean values and the key.

beforeHook2 :: ClientEnv -> PGInfo -> RedisInfo -> IO (Bool, Bool, Int64)
beforeHook2 clientEnv pgInfo redisInfo = do
  userKeyEither <- runClientM (createUserClient testUser) clientEnv
  case userKeyEither of
    Left _ -> error "DB call failed on spec 2!"
    Right userKey -> do 
      inPG <- isJust <$> fetchUserPG pgInfo userKey
      inRedis <- isJust <$> fetchUserRedis redisInfo userKey
      return (inPG, inRedis, userKey)

Now our spec will look similar. This time we expect to find a user in Postgres, but not in Redis.

spec2 :: SpecWith (Bool, Bool, Int64)
spec2 = describe "After creating the user but not fetching" $ do
  it "There should be a user in Postgres" $ \(inPG, _, _) -> inPG `shouldBe` True
  it "There should be no user in Redis" $ \(_, inRedis, _) -> inRedis `shouldBe` False

Now we need to add the after hook, which will delete the user from the database and cache. Of course, we expect the user won’t exist in the cache, but we include this since we’ll need it in the final example:

afterHook :: PGInfo -> RedisInfo -> (Bool, Bool, Int64) -> IO ()
afterHook pgInfo redisInfo (_, _, key) = do
  deleteUserCache redisInfo key
  deleteUserPG pgInfo key

Last, we’ll write one more test case. This will mimic the previous case, except we’ll throw in a call to fetch in between. As a result, we expect the user to be in both Postgres and Redis:

beforeHook3 :: ClientEnv -> PGInfo -> RedisInfo -> IO (Bool, Bool, Int64)
beforeHook3 clientEnv pgInfo redisInfo = do
  userKeyEither <- runClientM (createUserClient testUser) clientEnv
  case userKeyEither of
    Left _ -> error "DB call failed on spec 3!"
    Right userKey -> do 
      _ <- runClientM (fetchUserClient userKey) clientEnv 
      inPG <- isJust <$> fetchUserPG pgInfo userKey
      inRedis <- isJust <$> fetchUserRedis redisInfo userKey
      return (inPG, inRedis, userKey)

spec3 :: SpecWith (Bool, Bool, Int64)
spec3 = describe "After creating the user and fetching" $ do
  it "There should be a user in Postgres" $ \(inPG, _, _) -> inPG `shouldBe` True
  it "There should be a user in Redis" $ \(_, inRedis, _) -> inRedis `shouldBe` True

And it will use the same after hook as case 2, so we’re done!

Hooking in and running the tests

The last step is to glue all our pieces together with hspec, before, and after. Here’s our main function, which also kills the thread running the server once it’s done:

main :: IO ()
main = do
  (pgInfo, redisInfo, clientEnv, tid) <- setupTests
  hspec $ before (beforeHook1 clientEnv pgInfo redisInfo) spec1
  hspec $ before (beforeHook2 clientEnv pgInfo redisInfo) $ after (afterHook pgInfo redisInfo) $ spec2
  hspec $ before (beforeHook3 clientEnv pgInfo redisInfo) $ after (afterHook pgInfo redisInfo) $ spec3
  killThread tid 
  return ()

And now our tests should pass!

After fetching on an empty database
  The fetch call should throw an error
  There should be no user in Postgres
  There should be no user in Redis

Finished in 0.0410 seconds
3 examples, 0 failures

After creating the user but not fetching
  There should be a user in Postgres
  There should be no user in Redis

Finished in 0.0585 seconds
2 examples, 0 failures

After creating the user and fetching
  There should be a user in Postgres
  There should be a user in Redis

Finished in 0.0813 seconds
2 examples, 0 failures

Using Docker

So when I say “the tests pass”, I mean they now pass on my system. But if you were to clone the code as is and try to run them, you would get failures. The tests depend on Postgres and Redis, so if you don't have them running, they fail! It is quite annoying to have your tests depend on these outside services. This is the weakness of designing our tests this way: it increases the on-boarding time for anyone coming into your codebase. A new person has to figure out which services they need to run, install them, and so on.

So how do we fix this? One answer is to use Docker. Docker allows you to create containers that have particular services running within them. This spares you from worrying about the details of setting up the services on your local machine. Even more important, you can deploy Docker images to your remote environments, so your development and production environments will match your local system. To set up this process, we’ll create a description of the services we want running in Docker containers. We do this with a docker-compose file. Here’s what ours looks like:

version: '2'

services:
  postgres:
    image: postgres:9.6
    container_name: prod-haskell-series-postgres
    ports:
      - "5432:5432"

  redis:
    image: redis:4.0
    container_name: prod-haskell-series-redis
    ports:
      - "6379:6379"

Then, you can start these services for your Docker machines with docker-compose up. Granted, you do have to install and run Docker. But if you have several different services, this is a much easier on-boarding process. Better yet, the "compose" file ensures everyone uses the same versions of these services.

Even with this container running, the tests will still fail! That’s because you also need the tests themselves to be running on your Docker cluster. But with Stack, this is easy! We’ll add the following flag to our stack.yaml file:

docker:
  enable: true

Now, whenever you build and test your program, you will do so on Docker. The first time you do this, Docker will need to set everything up on the container. This means it will have to download Stack and ALL the different packages you use. So the first run will take a while. But subsequent runs will be normal. So after all that finishes, NOW the tests should work!

Conclusion

Testing integrated systems is hard. We can try mocking out the behavior of external services, but this can lead to a test representation of our program that isn’t faithful to the production system. Using the before and after hooks from Hspec is a great way to make sure all your external events happen first. Then you can pass those results to simpler test assertions.

When it comes time to run your system, it helps if you can bring up all your external services with one command! Docker allows you to do this by listing the different services in the docker-compose file. Then, Stack makes it easy to run your program and tests on a docker container, so you can use the services!

Stack is the key to all this integration. If you’ve never used Stack before, you should check out our free mini-course. It will teach you all the basics of organizing a Haskell project using Stack.

If this is your first exposure to Haskell, I’ve hopefully convinced you of some of its awesome possibilities! Take a look at our Getting Started Checklist and get learning!

James Bowen

A Cache is Fast: Enhancing our API with Redis

In the last couple weeks we’ve used Persistent to store a User type in a Postgresql database. Then we were able to use Servant to create a very simple API that exposed this database to the outside world. This week, we’re going to look at how we can improve the performance of our API using a Redis cache.

One cannot overstate the importance of caching in both software and hardware. There's a hierarchy of memory types, from registers, to RAM, to the file system, to a remote database. Accessing each of these gets progressively slower (by orders of magnitude). But the faster means of storage are more expensive, so we can’t always have as much as we'd like.

But memory usage follows a very important principle: when we use a piece of memory once, we’re very likely to use it again in the near future. So when we pull something out of long-term memory, we can temporarily store it in short-term memory as well. This way, when we need it again, we can get it faster. After a certain point, that item will be overwritten by other, more urgent items. This is the essence of caching.

Redis 101

Redis is an application that allows us to create a key-value store of items. It functions like a database, except that it only supports lookups by key. It lacks the sophistication of joins, foreign table references and indices, so we can’t run the kinds of sophisticated queries that are possible on an SQL database. But we can run simple key lookups, and we can do them faster. In this article, we'll use Redis as a short-term cache for our user objects.

For this article, we've got one main goal for cache integration. Whenever we “fetch” a user using the GET endpoint in our API, we want to store that user in our Redis cache. Then the next time someone requests that user from our API, we'll grab them out of the cache. This will save us the trouble of making a longer call to our Postgres database.

Connecting to Redis

Haskell's Redis library has a lot of similarities to Persistent and Postgres. First, we’ll need some sort of data that tells us where to look for our database. For Postgres, we used a simple ConnectionString with a particular format. Redis uses a full data type called ConnectInfo.

data ConnectInfo = ConnectInfo
  { connectHost :: HostName -- String
  , connectPort :: PortId   -- (Can just be a number)
  , connectAuth :: Maybe ByteString
  , connectDatabase :: Integer
  , connectMaxConnection :: Int
  , connectMaxIdleTime :: NominalDiffTime
  }

This has many of the same fields we stored in our PG string, like the host IP address, and the port number. The rest of this article assumes you are running a local Redis instance at port 6379. This means we can use defaultConnectInfo. As always, in a real system you’d want to grab this information out of a configuration, so you’d need IO.

fetchRedisConnection :: IO ConnectInfo
fetchRedisConnection = return defaultConnectInfo
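Here is a sketch of what "grab this information out of a configuration" might look like, assuming the host and port come from environment variables named REDIS_HOST and REDIS_PORT (both hypothetical names). To keep it testable without Redis, it fills a small local RedisSettings record rather than hedis's ConnectInfo; in real code you would copy these values into the connectHost and connectPort fields of defaultConnectInfo.

```haskell
import Data.Maybe (fromMaybe)
import System.Environment (lookupEnv)
import Text.Read (readMaybe)

-- Hypothetical local record standing in for the two ConnectInfo fields we override.
data RedisSettings = RedisSettings
  { redisHost :: String
  , redisPort :: Int
  } deriving (Show, Eq)

-- Same host and port as defaultConnectInfo (a local Redis on port 6379).
defaultRedisSettings :: RedisSettings
defaultRedisSettings = RedisSettings { redisHost = "localhost", redisPort = 6379 }

-- Pure part: apply optional overrides; an unparsable port falls back to the default.
settingsFrom :: Maybe String -> Maybe String -> RedisSettings
settingsFrom mHost mPort = RedisSettings
  { redisHost = fromMaybe (redisHost defaultRedisSettings) mHost
  , redisPort = fromMaybe (redisPort defaultRedisSettings) (mPort >>= readMaybe)
  }

-- IO part: REDIS_HOST and REDIS_PORT are assumed names, not a hedis convention.
fetchRedisSettings :: IO RedisSettings
fetchRedisSettings = settingsFrom <$> lookupEnv "REDIS_HOST" <*> lookupEnv "REDIS_PORT"
```

Splitting the pure merging logic from the lookupEnv calls keeps the part with actual decisions easy to test.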

With Postgres, we used withPostgresqlConn to actually connect to the database. With Redis, we do this with the connect function. We'll get a Connection object that we can use to run Redis actions.

connect :: ConnectInfo -> IO Connection

With this connection, we simply use runRedis, and then combine it with an action. Here’s the wrapper runRedisAction we’ll write for that:

runRedisAction :: ConnectInfo -> Redis a -> IO a
runRedisAction redisInfo action = do
  connection <- connect redisInfo
  runRedis connection action

The Redis Monad

Just as we used the SqlPersistT monad with Persistent, we’ll use the Redis monad to interact with our Redis cache. Our API is simple, so we’ll stick to three basic functions. The real types of these functions are a bit more complicated. But this is because of polymorphism related to transactions, and we won't be using those.

get :: ByteString -> Redis (Either Reply (Maybe ByteString))
set :: ByteString -> ByteString -> Redis (Either Reply Status)
setex :: ByteString -> Integer -> ByteString -> Redis (Either Reply Status)

Redis is a key-value store, so everything we set here will use ByteString items. But once we’ve done that, these functions are all we need to use. The get function takes a ByteString of the key and delivers the value as another ByteString. The set function takes both the serialized key and value and stores them in the cache. The setex function does the same thing as set, except that it also takes an expiration time in seconds (between the key and the value) for the item we’re storing.

Expiration is a very useful feature to be aware of, since most relational databases don’t have it. The nature of a cache is that it’s only supposed to store a subset of our information at any given time. If we never expire or delete anything, it might eventually store our whole database. That would defeat the purpose of using a cache! Its memory footprint should remain low compared to our database. So we'll use setex in our API.
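To make expiration semantics concrete without a running server, the sketch below models SET, SETEX and GET over an in-memory association list. It is a hypothetical model, not the hedis API: time is passed in explicitly as a number of seconds so the behavior is deterministic, and expired entries are only hidden, not reclaimed, whereas a real cache also frees that memory.

```haskell
import Data.IORef

-- Hypothetical in-memory model of Redis SET/SETEX/GET semantics (not hedis).
-- Each entry stores its value and an optional expiry time, in seconds.
type Time = Integer
type Cache = IORef [(String, (String, Maybe Time))]

-- Like SET: store the value with no expiry, replacing any old entry.
setKey :: Cache -> String -> String -> IO ()
setKey cache k v =
  modifyIORef cache (\es -> (k, (v, Nothing)) : filter ((/= k) . fst) es)

-- Like SETEX key seconds value: the entry expires ttl seconds after now.
setKeyEx :: Cache -> Time -> String -> Integer -> String -> IO ()
setKeyEx cache now k ttl v =
  modifyIORef cache (\es -> (k, (v, Just (now + ttl))) : filter ((/= k) . fst) es)

-- Like GET: an expired entry behaves as if the key were absent.
getKey :: Cache -> Time -> String -> IO (Maybe String)
getKey cache now k = do
  entries <- readIORef cache
  return $ case lookup k entries of
    Just (v, Nothing) -> Just v
    Just (v, Just expiry) | now < expiry -> Just v
    _ -> Nothing
```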

Saving a User in Redis

So now let’s move on to the actions we’ll actually use in our API. First, we’ll write a function that stores a key-value pair of an Int64 key and the User in the cache. Here’s how we start:

cacheUser :: ConnectInfo -> Int64 -> User -> IO ()
cacheUser redisInfo uid user = runRedisAction redisInfo $ setex ??? ??? ???

All we need to do now is convert our key and our value to ByteString values. We'll keep it simple and use Data.ByteString.Char8 combined with our Show and Read instances. Then we’ll create a Redis action using setex and expire the key after 3600 seconds (one hour).

import Control.Monad (void)
import Data.ByteString.Char8 (pack, unpack)

...

cacheUser :: ConnectInfo -> Int64 -> User -> IO ()
cacheUser redisInfo uid user = runRedisAction redisInfo $ void $ 
  setex (pack . show $ uid) 3600 (pack . show $ user)

(We use void to ignore the result of the Redis call).

Fetching from Redis

Fetching a user is a similar process. We’ll take the connection information and the key we’re looking for. The action we’ll create uses the bytestring representation and calls get. But we can’t ignore the result of this call like we could before! Retrieving anything gives us Either e (Maybe ByteString). A Left response indicates an error, while Right Nothing indicates the key doesn’t exist. We’ll ignore the errors and treat the result as Maybe User though. If any error comes up, we’ll return Nothing. This means we run a simple pattern match:

fetchUserRedis :: ConnectInfo -> Int64 -> IO (Maybe User)
fetchUserRedis redisInfo uid = runRedisAction redisInfo $ do
  result <- Redis.get (pack . show $ uid)
  case result of
    Right (Just userString) -> return $ Just (read . unpack $ userString)
    _ -> return Nothing

If we do find something for that key, we’ll read it out of its ByteString format and then we’ll have our final User object.
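The encoding and decoding steps can be exercised without Redis. This sketch uses a simplified User record and a hypothetical ReplyError type in place of hedis's Reply, so the Either-over-Maybe layering is visible. It also swaps read for Text.Read.readMaybe, so a malformed cache entry decodes to Nothing (and we fall back to the database) instead of crashing the handler.

```haskell
import qualified Data.ByteString.Char8 as BS
import Data.Int (Int64)
import Text.Read (readMaybe)

-- Simplified User; the article's User comes from Persistent and needs Show/Read.
data User = User { userName :: String, userAge :: Int } deriving (Show, Read, Eq)

-- Hypothetical stand-in for hedis's Reply error type.
data ReplyError = ReplyError String deriving Show

-- Keys and values are stored as their Show representation.
encodeKey :: Int64 -> BS.ByteString
encodeKey = BS.pack . show

encodeUser :: User -> BS.ByteString
encodeUser = BS.pack . show

-- Collapse Either err (Maybe ByteString) into Maybe User:
-- errors, missing keys and unparsable values all become Nothing.
decodeUser :: Either ReplyError (Maybe BS.ByteString) -> Maybe User
decodeUser (Right (Just bytes)) = readMaybe (BS.unpack bytes)
decodeUser _ = Nothing
```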

Applying this to our API

Now that we’re all set up with our Redis functions, we have to update the fetchUsersHandler to use this cache. First, we need to pass the Redis connection information as another parameter. For ease of reading, we’ll refer to these using type synonyms (PGInfo and RedisInfo) from now on:

type PGInfo = ConnectionString
type RedisInfo = ConnectInfo

…

fetchUsersHandler :: PGInfo -> RedisInfo -> Int64 -> Handler User
fetchUsersHandler pgInfo redisInfo uid = do
  ...

The first thing we’ll try is to look up the user by their ID in the Redis cache. If the user exists, we’ll immediately return that user.

fetchUsersHandler :: PGInfo -> RedisInfo -> Int64 -> Handler User
fetchUsersHandler pgInfo redisInfo uid = do
  maybeCachedUser <- liftIO $ fetchUserRedis redisInfo uid
  case maybeCachedUser of
    Just user -> return user
    Nothing -> do
      ...

If the user doesn’t exist, we’ll then drop into the logic of fetching the user in the database. We’ll replicate our logic of throwing an error if we find that user doesn’t actually exist. But if we find the user, we need one more step. Before we return it, we should call cacheUser and store it for the future.

fetchUsersHandler :: PGInfo -> RedisInfo -> Int64 -> Handler User
fetchUsersHandler pgInfo redisInfo uid = do
  maybeCachedUser <- liftIO $ fetchUserRedis redisInfo uid
  case maybeCachedUser of
    Just user -> return user
    Nothing -> do
      maybeUser <- liftIO $ fetchUserPG pgInfo uid
      case maybeUser of
        Just user -> liftIO (cacheUser redisInfo uid user) >> return user
        Nothing -> Handler $ (throwE $ err401 { errBody = "Could not find user with that ID" })
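This read path is the standard cache-aside pattern. Below is a base-only sketch of the same control flow, with two IORef association lists as hypothetical stand-ins for Postgres and Redis, and an Either String in place of the Handler error. It shows that a database hit leaves the user in the cache, while a miss caches nothing.

```haskell
import Data.IORef
import Data.Int (Int64)

-- Hypothetical in-memory stand-ins for Postgres and Redis.
data Stores = Stores
  { pgRows :: IORef [(Int64, String)]
  , redisRows :: IORef [(Int64, String)]
  }

-- Cache-aside read: try the cache, fall back to the database,
-- and populate the cache only when the database has the user.
fetchUser :: Stores -> Int64 -> IO (Either String String)
fetchUser stores uid = do
  cached <- lookup uid <$> readIORef (redisRows stores)
  case cached of
    Just user -> return (Right user)
    Nothing -> do
      fromDB <- lookup uid <$> readIORef (pgRows stores)
      case fromDB of
        Just user -> do
          modifyIORef (redisRows stores) ((uid, user) :)
          return (Right user)
        Nothing -> return (Left "Could not find user with that ID")
```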

Since we changed our type signature, we’ll have to make a few other updates as well, but these are quite simple:

usersServer :: PGInfo -> RedisInfo -> Server UsersAPI
usersServer pgInfo redisInfo =
  (fetchUsersHandler pgInfo redisInfo) :<|> 
  (createUserHandler pgInfo)


runServer :: IO ()
runServer = do
  pgInfo <- fetchPostgresConnection
  redisInfo <- fetchRedisConnection
  run 8000 (serve usersAPI (usersServer pgInfo redisInfo))

And that’s it! We have a functioning cache with expiring entries. This means that repeated queries to our fetch endpoint should be much faster!

Conclusion

Caching is a vitally important technique for making our software faster for our users. Redis is a key-value store that we can use as a cache for our most frequently used data, so that not every single API call has to hit our database. In Haskell, the Redis API requires everything to be a ByteString, so we have to deal with some logic surrounding encoding and decoding. But otherwise it operates in a very similar way to Persistent and Postgres.

Be sure to take a look at this code on Github! There’s a redis branch for this article. It includes all the code samples, including things I skipped over like imports!

We’re starting to get to the point where we’re using a lot of different libraries in our Haskell application! It pays to know how to organize everything, so package management is vital! I tend to use Stack for all my package management. It makes it quite easy to bring all these different libraries together. If you want to learn how to use Stack, check out our free Stack mini-course!

If you’ve never learned Haskell before, you should try it out! Download our Getting Started Checklist!

James Bowen

Serve it up with Servant!

Last week we began our series on production Haskell techniques by learning about Persistent. We created a schema that contained a single User type that we could store in a Postgresql database. We examined a couple functions allowing us to make SQL queries about these users.

This week, we’ll see how we can expose this database to the outside world using an API. We’ll construct our API using the Servant library. Servant involves some advanced type level constructs, so there’s a lot to wrap your head around. There are definitely simpler approaches to HTTP servers than what Servant uses. But I’ve found that the power Servant gives us is well worth the effort.

This article will give a brief overview of Servant. But if you want a more in-depth introduction, you should check out my talk from BayHac last spring! That talk was more exhaustive about the different combinators you can use in your APIs. It also showed authentication techniques, client functions and documentation. You can also check out the slides and code for that presentation!

Also, take a look at the servant branch on the Github repo for this project to see all the code for this article!

Defining our API

The first step in writing an API for our user database is to decide what the different endpoints are. We can decide this independent of what language or library we’ll use. For this article, our API will have two different endpoints. The first will be a POST request to /users. This request will contain a “user” definition in its body, and the result will be that we’ll create a user in our database. Here’s a sample of what this might look like:

POST /users
{
  "userName": "John Doe",
  "userEmail": "john@doe.com",
  "userAge": 29,
  "userOccupation": "Teacher"
}

It will then return a response containing the database key of the user we created. This will allow any clients to fetch the user again. The second endpoint will use the ID to fetch a user by their database identifier. It will be a GET request to /users/:userid. So for instance, the last request might have returned us something like 16. We could then do the following:

GET /users/16

And our response would look like the request body from above.

An API as a Type

So we’ve got our very simple API. How do we actually define this in Haskell, and more specifically with Servant? Well, Servant does something pretty unique (as far as I’ve researched): we define our API by using a type. Our type will include sub-types for each of the endpoints of our API. We combine the different endpoints by using the (:<|>) operator. I'll sometimes refer to this as “E-plus”, for “endpoint-plus”. This is a type operator, like some of the operators we saw with dependent types and TensorFlow. Here’s the blueprint of our API:

type UsersAPI =
  FetchEndpoint
  :<|> CreateEndpoint

Now let's define what we mean by FetchEndpoint and CreateEndpoint. Each endpoint chains together combinators that describe different information about it. We link combinators together with the (:>) operator, which I call “C-plus” (combinator plus). Here’s what our final API looks like. We’ll go through what each combinator means in the next section:

type UsersAPI =
       "users" :> Capture "userid" Int64 :> Get '[JSON] User
  :<|> "users" :> ReqBody '[JSON] User :> Post '[JSON] Int64

Different Combinators

Both of these endpoints have three different combinators. Let’s start by examining the fetch endpoint. It starts off with a string combinator. This is a path component, which specifies the URL path the caller should use to hit the endpoint. We can use this combinator multiple times to build a longer path. If we instead wanted this endpoint to be at /api/users/:userid, we’d change it to:

"api" :> "users" :> Capture "userid" Int64 :> Get '[JSON] User

The second combinator (Capture) allows us to get a value out of the URL itself. We give this value a name and then we supply a type parameter. We won't have to do any path parsing or manipulation ourselves. Servant will handle the tricky business of parsing the URL and passing us an Int64. If you want to use your own custom type as a piece of HTTP data, that's not too difficult. You’ll just have to write an instance of the FromHttpApiData class. All the basic types like Int64 already have instances.
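For a sense of what Capture saves us, here is roughly the manual parsing we would otherwise write for /users/:userid. The helpers splitPath and parseUserPath are hypothetical; Servant's real parsing goes through FromHttpApiData and its routing machinery.

```haskell
import Data.Int (Int64)
import Text.Read (readMaybe)

-- Split "/users/16" into ["users", "16"], dropping empty segments.
splitPath :: String -> [String]
splitPath = filter (not . null) . go
  where
    go s = case break (== '/') s of
      (seg, []) -> [seg]
      (seg, _ : rest) -> seg : go rest

-- Match the path segments ["users", <id>] and parse the id as an Int64.
-- A non-numeric id or a different path yields Nothing.
parseUserPath :: [String] -> Maybe Int64
parseUserPath ["users", rawId] = readMaybe rawId
parseUserPath _ = Nothing
```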

The final combinator itself contains three important pieces of information for this endpoint. First, it tells us that this is in fact a GET request. Second, it gives us the list of content types that are allowable in the response. This is a type-level list of content formats. For each format in this list, our data type needs instances that let Servant encode it in that format (and decode it, where it appears in a request body). We could have used a more complicated list like '[JSON, PlainText, OctetStream]. But for the rest of this article, we’ll just use JSON. This means we'll use the ToJSON and FromJSON typeclasses for serialization.

The last piece of this combinator is the type our endpoint returns. So a successful request will give the caller back a response that contains a User in JSON format. Notice this isn’t a Maybe User. If the ID is not in our database, we’ll return a 401 error to indicate failure, rather than returning Nothing.

Our second endpoint has many similarities. It uses the same string path component. Then its final combinator is the same except that it indicates it is a POST request instead of a GET request. The second combinator then tells us what we can expect the request body to look like. In this case, the request body should contain a JSON representation of a User. It requires a list of acceptable content types, and then the type we want, like the Get and Post combinators.
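To see how a type alone can carry all this information, here is a toy, base-only model of the idea. The combinators below are hypothetical stand-ins loosely modeled on Servant's (no content types, no handlers), and a type class walks the API type to render one route per endpoint. Servant's real classes, such as HasServer, also have one instance per combinator, but do far more.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

import Data.Int (Int64)
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)

-- Toy combinators (hypothetical; simplified from Servant's).
data (path :: k) :> (rest :: Type)
infixr 4 :>
data (a :: Type) :<|> (b :: Type)
infixr 3 :<|>
data Capture (name :: Symbol) (t :: Type)
data ReqBody (t :: Type)
data Get (t :: Type)
data Post (t :: Type)

data User

type UsersAPI =
       "users" :> Capture "userid" Int64 :> Get User
  :<|> "users" :> ReqBody User :> Post Int64

-- Walk the API type and render one "METHOD /path" string per endpoint.
class HasRoutes (api :: Type) where
  routesFrom :: String -> Proxy api -> [String]

instance (HasRoutes a, HasRoutes b) => HasRoutes (a :<|> b) where
  routesFrom prefix _ = routesFrom prefix (Proxy :: Proxy a) ++ routesFrom prefix (Proxy :: Proxy b)

-- A type-level string is a path segment.
instance (KnownSymbol path, HasRoutes rest) => HasRoutes ((path :: Symbol) :> rest) where
  routesFrom prefix _ =
    routesFrom (prefix ++ "/" ++ symbolVal (Proxy :: Proxy path)) (Proxy :: Proxy rest)

-- A Capture becomes a ":name" placeholder.
instance (KnownSymbol name, HasRoutes rest) => HasRoutes (Capture name t :> rest) where
  routesFrom prefix _ =
    routesFrom (prefix ++ "/:" ++ symbolVal (Proxy :: Proxy name)) (Proxy :: Proxy rest)

-- A request body does not change the path.
instance HasRoutes rest => HasRoutes (ReqBody t :> rest) where
  routesFrom prefix _ = routesFrom prefix (Proxy :: Proxy rest)

instance HasRoutes (Get t) where
  routesFrom prefix _ = ["GET " ++ prefix]

instance HasRoutes (Post t) where
  routesFrom prefix _ = ["POST " ++ prefix]

apiRoutes :: [String]
apiRoutes = routesFrom "" (Proxy :: Proxy UsersAPI)
```

Evaluating apiRoutes produces the two routes from our API description: a GET on /users/:userid and a POST on /users.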

That completes the “definition” of our API. We’ll need to add ToJSON and FromJSON instances of our User type in order for this to function. You can take a look at those on Github, and check out this article for more details on creating those instances!

Writing Handlers

Now that we’ve defined the type of our API, we need to write handler functions for each endpoint. This is where Servant’s awesomeness kicks in. We can map each endpoint up to a function that has a particular type based on the combinators in the endpoint. So, first let’s remember our endpoint for fetching a user:

"users" :> Capture "userid" Int64 :> Get '[JSON] User

The string path component doesn’t add any arguments to our function. The Capture component will result in a parameter of type Int64 that we’ll need in our function. Then the return type of our function should be User. This almost completely defines the type signature of our handler. We'll note though that it needs to be in the Handler monad. So here’s what it’ll look like:

fetchUsersHandler :: Int64 -> Handler User
...

Now let’s look at the type of our create endpoint:

"users" :> ReqBody '[JSON] User :> Post '[JSON] Int64

The ReqBody combinator adds a parameter whose type is ReqBody's type argument, here User. So this endpoint resolves into the handler type:

createUserHandler :: User -> Handler Int64
...
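The mapping from endpoint type to handler type can itself be written as a type-level function. The sketch below is a simplified, hypothetical version using a closed type family (in the real library, Servant's ServerT does this job). It handles single endpoints only, drops the connection parameter, and uses IO in place of Handler.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

import Data.Int (Int64)
import Data.Kind (Type)
import GHC.TypeLits (Symbol)

-- Toy combinators (hypothetical; simplified from Servant's).
data (path :: k) :> (rest :: Type)
infixr 4 :>
data Capture (name :: Symbol) (t :: Type)
data ReqBody (t :: Type)
data Get (t :: Type)
data Post (t :: Type)

-- Toy stand-in for Servant's Handler monad.
type Handler = IO

-- Compute an endpoint's handler type, one combinator at a time.
type family HandlerFor (api :: Type) :: Type where
  HandlerFor ((path :: Symbol) :> rest) = HandlerFor rest      -- path pieces add nothing
  HandlerFor (Capture name t :> rest) = t -> HandlerFor rest  -- a capture adds an argument
  HandlerFor (ReqBody t :> rest) = t -> HandlerFor rest       -- so does a request body
  HandlerFor (Get t) = Handler t
  HandlerFor (Post t) = Handler t

data User = User { userName :: String, userAge :: Int } deriving (Show, Eq)

type FetchEndpoint = "users" :> Capture "userid" Int64 :> Get User
type CreateEndpoint = "users" :> ReqBody User :> Post Int64

-- Handlers written with concrete types, as in the article.
fetchUsersHandler :: Int64 -> Handler User
fetchUsersHandler uid = return (User ("user" ++ show uid) 29)

createUserHandler :: User -> Handler Int64
createUserHandler _ = return 16

-- These only compile because HandlerFor computes exactly the types above;
-- changing Int64 to String in fetchUsersHandler would make this a type error.
fetchChecked :: HandlerFor FetchEndpoint
fetchChecked = fetchUsersHandler

createChecked :: HandlerFor CreateEndpoint
createChecked = createUserHandler
```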

Now, we’ll need to be able to access our Postgres database through both these handlers. So they’ll each get an extra parameter referring to the ConnectionString. We’ll pass that from our code so that by the time Servant is resolving the types, the parameter is accounted for:

fetchUsersHandler :: ConnectionString -> Int64 -> Handler User
createUserHandler :: ConnectionString -> User -> Handler Int64

Before we go any further, we should discuss the Handler monad. This is a wrapper around the monad ExceptT ServantErr IO. In other words, each of these requests might fail. To make it fail, we can throw errors of type ServantErr. Then of course we can also call IO functions, because these are network operations.

Before we implement these functions, let’s first write a couple simple helpers. These will use the runAction function from last week’s article to run database actions:

fetchUserPG :: ConnectionString -> Int64 -> IO (Maybe User)
fetchUserPG connString uid = runAction connString (get (toSqlKey uid))

createUserPG :: ConnectionString -> User -> IO Int64
createUserPG connString user = fromSqlKey <$> runAction connString (insert user)

For completeness (and use later in testing), we’ll also add a simple delete function. We need the where clause for type inference:

deleteUserPG :: ConnectionString -> Int64 -> IO ()
deleteUserPG connString uid = runAction connString (delete userKey)
  where
    userKey :: Key User
    userKey = toSqlKey uid
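Why is that annotation needed? Key is parameterized by the record type, but toSqlKey is polymorphic in that type, and nothing else in deleteUserPG pins it down. The base-only sketch below reproduces the situation with hypothetical stand-ins (Key, toKey, Entity, deleteSql) rather than Persistent's real definitions.

```haskell
import Data.Int (Int64)

-- Hypothetical stand-ins for Persistent's Key, toSqlKey and delete.
-- `record` is a phantom type: it appears only in the type, not in the data.
newtype Key record = Key Int64 deriving (Show, Eq)

toKey :: Int64 -> Key record
toKey = Key

-- Each record type knows its own table, so the phantom type picks the instance.
class Entity record where
  tableName :: Key record -> String

data User

instance Entity User where
  tableName _ = "users"

deleteSql :: Entity record => Key record -> String
deleteSql key@(Key rawId) = "DELETE FROM " ++ tableName key ++ " WHERE id = " ++ show rawId

-- Writing `deleteSql (toKey uid)` directly would be ambiguous:
-- GHC could not tell which Entity instance to use.
deleteUserSql :: Int64 -> String
deleteUserSql uid = deleteSql userKey
  where
    userKey :: Key User
    userKey = toKey uid
```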

Now from our Servant handlers, we’ll call these two functions. This will completely cover the case of the create endpoint. But we’ll need a little bit more logic for the fetch endpoint. Since our functions are in the IO monad, we have to lift them up to Handler.

fetchUsersHandler :: ConnectionString -> Int64 -> Handler User
fetchUsersHandler connString uid = do
  maybeUser <- liftIO $ fetchUserPG connString uid
  ...

createUserHandler :: ConnectionString -> User -> Handler Int64
createUserHandler connString user = liftIO $ createUserPG connString user

To complete our fetch handler, we need to account for a non-existent user. Instead of making the type of the whole endpoint a Maybe, we’ll throw a ServantErr in this case. We can use one of the built-in Servant error functions, which correspond to normal error codes. Then we can update the body. In this case, we’ll throw a 401 error. Here’s how we do that:

import qualified Data.ByteString.Lazy.Char8 as BSL

fetchUsersHandler :: ConnectionString -> Int64 -> Handler User
fetchUsersHandler connString uid = do
  maybeUser <- liftIO $ fetchUserPG connString uid
  case maybeUser of
    Just user -> return user
    -- errBody is a lazy ByteString, so we pack the message instead of appending Strings to it
    Nothing -> Handler $ (throwE $ err401 { errBody = BSL.pack ("Could not find user with ID: " ++ show uid) })

createUserHandler :: ConnectionString -> User -> Handler Int64
createUserHandler connString user = liftIO $ createUserPG connString user

And that’s it! We're done with our handler functions!
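Since Handler wraps ExceptT ServantErr IO, the same control flow can be tried out with plain transformers types. In this sketch ApiError, apiErr401 and lookupUser are hypothetical, simplified stand-ins for ServantErr, err401 and fetchUserPG. The Handler here is just a type alias, so it calls throwE directly rather than going through the Handler constructor.

```haskell
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE)
import Data.Int (Int64)

-- Simplified, hypothetical stand-ins for ServantErr and err401.
data ApiError = ApiError { errHTTPCode :: Int, errBody :: String } deriving (Show, Eq)

apiErr401 :: ApiError
apiErr401 = ApiError { errHTTPCode = 401, errBody = "" }

-- Mirrors the shape of Servant's Handler: failures are ApiError values, effects are IO.
type Handler = ExceptT ApiError IO

data User = User { userName :: String } deriving (Show, Eq)

-- Hypothetical database lookup standing in for fetchUserPG.
lookupUser :: Int64 -> IO (Maybe User)
lookupUser 16 = return (Just (User "John Doe"))
lookupUser _ = return Nothing

fetchUsersHandler :: Int64 -> Handler User
fetchUsersHandler uid = do
  maybeUser <- liftIO (lookupUser uid)  -- IO actions are lifted into Handler
  case maybeUser of
    Just user -> return user
    Nothing -> throwE apiErr401 { errBody = "Could not find user with ID: " ++ show uid }
```

Running a handler with runExceptT gives back an Either, so a missing user shows up as a Left carrying the 401 error rather than as an exception.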

Combining it All into a Server

Our next step is to create an object of type Server over our API. This is actually remarkably simple. When we defined the original type, we combined the endpoints with the (:<|>) operator. To make our Server, we do the same thing but with the handler functions:

usersServer :: ConnectionString -> Server UsersAPI
usersServer pgInfo = 
  (fetchUsersHandler pgInfo) :<|> 
  (createUserHandler pgInfo)

And Servant does all the work of ensuring that the type of each endpoint matches up with the type of the handler! It’s pretty awesome. Suppose we changed the type of our fetchUsersHandler so that it took a Key User instead of an Int64. We’d get a compile error:

fetchUsersHandler :: ConnectionString -> Key User -> Handler User
…

-- Compile Error!
• Couldn't match type ‘Key User’ with ‘Int64’
      Expected type: Server UsersAPI
        Actual type: (Key User -> Handler User)
                     :<|> (User -> Handler Int64)

There's now a mismatch between our API definition and our handler definition. So Servant knows to throw an error! The one issue is that the error messages can be rather difficult to interpret sometimes. This is especially the case when your API becomes very large! The “Actual type” section of the above error will become massive! So always be careful when changing your endpoints! Frequent compilation is your friend!

Building the Application

The final piece of the puzzle is to actually build an Application object out of our server. The first step of this process is to create a Proxy for our API. Remember that our API is a type, and not a term. But a Proxy allows us to represent this type at the term level. The concept is a little complicated, but the code is not!

import Data.Proxy

…

usersAPI :: Proxy UsersAPI
usersAPI = Proxy :: Proxy UsersAPI
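A Proxy carries no data at runtime. Its type argument is what tells a function which instance or type-level value to use, which is exactly how serve learns which API type to build a server for. The small base-only example below shows the idea with GHC.TypeLits.symbolVal and a hypothetical Describe class.

```haskell
{-# LANGUAGE DataKinds #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (symbolVal)

-- The Proxy's type argument is a type-level string; symbolVal brings it to the term level.
usersPath :: String
usersPath = symbolVal (Proxy :: Proxy "users")

-- Hypothetical class: the Proxy argument only selects which instance runs.
class Describe a where
  describe :: Proxy a -> String

instance Describe Int where
  describe _ = "a machine integer"

instance Describe Bool where
  describe _ = "a boolean"
```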

Now we can make our runnable Application like so (assuming we have a Postgres connection):

serve usersAPI (usersServer connString)

We’ll run this server on port 8000 by using the run function from Network.Wai.Handler.Warp. (See Github for a full list of imports.) We’ll fetch our connection string, and then we’re good to go!

runServer :: IO ()
runServer = do
  pgInfo <- fetchPostgresConnection
  run 8000 (serve usersAPI (usersServer pgInfo))

Conclusion

The Servant library offers some truly awesome possibilities. We’re able to define a web API at the type level. We can then define handler functions using the parameters the endpoints expect. Servant handles all the work of marshalling back and forth between the HTTP request and the native Haskell types. It also ensures a match between the endpoints and the handler function types!

If you want to see even more of the possibilities that Servant offers, you should watch my talk from BayHac. It goes through some more advanced concepts like authentication and client side functions. You can get the slides and all the code examples for that talk here.

If you’ve never tried Haskell before, there’s no time like the present to start! Download our Getting Started Checklist for some tools to help start your Haskell journey!

James Bowen

Trouble with Databases? Persevere with Persistent!

Our recent series at Monday Morning Haskell focused on machine learning. In particular, we did a deep dive on the Haskell TensorFlow library. While AI is a huge area indeed, it doesn't account for the bulk of day-to-day work. To build even a basic production system, there are a multitude of simpler tasks. In our newest series, we’ll be learning a host of different libraries to perform these tasks!

In this first article we’ll discuss Persistent. Many libraries allow you to make a quick SQL call. But Persistent does much more than that. With Persistent, you can link your Haskell types to your database definition. You can also make type-safe queries to save yourself the hassle of decoding data. All in all, it's a very cool system.

All the code for this series will be on Github! To follow along with this article, take a look at the persistent branch.

Our Basic Type

We’ll start by considering a simple user type that looks like this:

data User = User
  { userName :: Text
  , userEmail :: Text
  , userAge :: Int
  , userOccupation :: Text
  }

Imagine we want to store objects of this type in an SQL database. We’ll first need to define the table to store our users. We could do this with a manual SQL command or through an editor, but regardless, the process will be error prone. The command would look something like this:

create table users (
  name varchar(100),
  email varchar(100),
  age bigint,
  occupation varchar(100)
)

When we do this, there's nothing linking our Haskell data type to the table structure. If we update the Haskell code, we have to remember to update the database. And this means writing another error-prone command.

From our Haskell program, we’ll also want to make SQL queries based on the structure of the user. We could write out these raw commands and execute them, but the same issues apply. This method is error-prone and not at all type-safe. Persistent helps us solve these problems.

Persistent and Template Haskell

We can get these bonuses from Persistent without all that much extra code! To do this, we’re going to use Template Haskell (TH). We’ve seen TH once in the past when we were deriving lenses and prisms for different data types. On that occasion we noted a few pros and cons of TH. It does allow us to avoid writing some boilerplate code. But it will make our compile times longer as well. It will also make our code less accessible to inexperienced Haskellers. With lenses though, it only saved us a few dozen lines of total code. With Persistent, TH generates a lot more code, so the pros definitely outweigh the cons.

When we created lenses with TH, we used a simple declaration makeLenses. Here, we’ll do something a bit more complicated. We’ll use a language construct called a “quasi-quoter”. This is a block of code that follows some syntax designed by the programmer or in a library, rather than normal Haskell syntax. It is often used in libraries that do some sort of foreign function interface. We delimit a quasi-quoter by a combination of brackets and pipes. Here’s what the Template Haskell call looks like. The quasi-quoter is the final argument:

import qualified Database.Persist.TH as PTH

PTH.share [PTH.mkPersist PTH.sqlSettings, PTH.mkMigrate "migrateAll"] [PTH.persistLowerCase|

|]

The share function takes a list of settings and then the quasi-quoter itself. It then generates the necessary Template Haskell for our data schema. Within this section, we’ll define all the different types our database will use. The settings list controls how those types get handled. In particular we specify sqlSettings, so everything we do here will focus on an SQL database. More importantly, we also create a migration function, migrateAll. After this Template Haskell gets compiled, this function will allow us to migrate our DB. This means it will create all our tables for us!

But before we see this in action, we need to re-define our user type. Instead of defining User in the normal Haskell way, we’re going to define it within the quasi-quoter. Note that this level of Template Haskell requires many compiler extensions. Here’s our definition:

{-# LANGUAGE TemplateHaskell            #-}
{-# LANGUAGE QuasiQuotes                #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE OverloadedStrings          #-}

import           Data.Text (Text)
import qualified Database.Persist.TH as PTH

PTH.share [PTH.mkPersist PTH.sqlSettings, PTH.mkMigrate "migrateAll"] [PTH.persistLowerCase|
  User sql=users
    name Text
    email Text
    age Int
    occupation Text
    UniqueEmail email
    deriving Show Read
|]

There are a lot of similarities to a normal data definition in Haskell. We’ve changed the formatting and reversed the order of the types and names. But you can still tell what’s going on. The field names are all there. We’ve still derived basic instances like we would in Haskell.

But we’ve also added some new directives. For instance, we’ve stated what the table name should be (by default it would be user, not users). We’ve also created a UniqueEmail constraint. This tells our database that each user has to have a unique email. The migration will handle creating all the necessary indices for this to work!
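To see what the UniqueEmail constraint buys us, here is a small in-memory model of the same rule. It is only a sketch, assuming a plain Data.Map stands in for the users table; UserRow, UsersTable and insertUniqueUser are hypothetical names, not code that Persistent generates. Against a real database, the unique index enforces this for us.

```haskell
import qualified Data.Map.Strict as Map

-- Hypothetical stand-in for a row of the users table.
data UserRow = UserRow
  { rowName  :: String
  , rowEmail :: String
  } deriving (Show, Eq)

-- The "table" is keyed by email, playing the role of the unique index
-- that the migration creates for UniqueEmail.
type UsersTable = Map.Map String UserRow

-- Insert a row only if no existing row has the same email.
-- Nothing signals a uniqueness violation.
insertUniqueUser :: UserRow -> UsersTable -> Maybe UsersTable
insertUniqueUser row table
  | Map.member (rowEmail row) table = Nothing
  | otherwise = Just (Map.insert (rowEmail row) row table)
```

Inserting two users with different emails succeeds, while a second insert with an email already in the table comes back as Nothing.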

This Template Haskell will generate the normal Haskell data type for us. All fields will have the prefix user and will be camel-cased (userName, userEmail, and so on), which is the default naming under sqlSettings. The compiler will also generate certain special instances for our type. These will enable us to use Persistent's type-safe query functions. Finally, this code generates field identifiers such as UserAge and UserEmail (we'll loosely call them lenses) that we'll use as filters in our queries, as we'll see later.

Entities and Keys

Persistent also has a construct allowing us to handle database IDs. For each type we put in the schema, we’ll have a corresponding Entity type. An Entity refers to a row in our database, and it associates a database ID with the object itself. The database ID has the type Key User (also available as the generated synonym UserId). For SQL backends it wraps an Int64, and the toSqlKey function builds one from an Int64. So the following would be a valid entity:

{-# LANGUAGE OverloadedStrings #-}

import Database.Persist (Entity(..))
import Database.Persist.Sql (toSqlKey)

sampleUser :: Entity User
sampleUser = Entity (toSqlKey 1) $ User
  { userName = "admin"
  , userEmail = "admin@test.com"
  , userAge = 23
  , userOccupation = "System Administrator"
  }

This is a nice little abstraction that keeps the database ID out of our user type, so the rest of our code can work with a plain User.
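Conceptually, an Entity is just a key paired with a value. The following self-contained model is a sketch of that idea rather than Persistent's actual definitions: the field names entityKey and entityVal mirror Persistent's, but Key here is a hypothetical phantom-typed newtype over Int64, and User is cut down to two fields.

```haskell
import Data.Int (Int64)

-- Local model: an ID tagged with the record type it identifies.
newtype Key record = Key Int64 deriving (Show, Eq, Ord)

-- Local model of an Entity: the ID travels next to the record,
-- so the record type itself never needs an ID field.
data Entity record = Entity
  { entityKey :: Key record
  , entityVal :: record
  } deriving Show

data User = User { userName :: String, userAge :: Int } deriving Show

-- Code that only cares about users works on plain User values...
describe :: User -> String
describe u = userName u ++ " (" ++ show (userAge u) ++ ")"

-- ...and only code that needs the ID has to deal with the Entity wrapper.
describeEntity :: Entity User -> String
describeEntity (Entity (Key k) u) = "#" ++ show k ++ " " ++ describe u
```

Because Key carries the record type as a phantom parameter, a Key User cannot be passed where some other record's key is expected, even though both wrap an Int64.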

The SqlPersistT Monad

So now that we have the basics of our schema, how do we actually interact with our database from Haskell code? As a specific example, we’ll be accessing a PostgreSQL database. For this we use the SqlPersistT monad transformer: all the query functions return actions in it. The transformer has to live on top of a monad that is MonadIO, since we obviously need IO to run database queries.

If we’re trying to make a database query from a normal IO function, the first thing we need is a ConnectionString. This string encodes information about the location of the database. The connection string generally has 4-5 components: the host/IP address, the port, the database username, the database name, and optionally a password. So for instance if you’re running Postgres on your local machine, you might have something like:

{-# LANGUAGE OverloadedStrings #-}

import Database.Persist.Postgresql (ConnectionString)

connString :: ConnectionString
connString = "host=127.0.0.1 port=5432 user=postgres dbname=postgres password=password"
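Rather than hard-coding the whole string, you could assemble it from its components. This is a hypothetical helper (ConnInfo and renderConnString are not part of persistent-postgresql), and it assumes none of the values contain spaces or quotes. It produces a plain String in the same key=value format as above; since ConnectionString is a ByteString, you would convert the result with Data.ByteString.Char8.pack.

```haskell
-- Hypothetical record of the pieces of a libpq-style connection string.
data ConnInfo = ConnInfo
  { ciHost     :: String
  , ciPort     :: Int
  , ciUser     :: String
  , ciDbName   :: String
  , ciPassword :: Maybe String  -- the optional fifth component
  }

-- Render as space-separated key=value pairs, omitting a missing password.
renderConnString :: ConnInfo -> String
renderConnString ci = unwords (map pair fields)
  where
    pair (k, v) = k ++ "=" ++ v
    fields =
      [ ("host", ciHost ci)
      , ("port", show (ciPort ci))
      , ("user", ciUser ci)
      , ("dbname", ciDbName ci)
      ] ++ maybe [] (\p -> [("password", p)]) (ciPassword ci)
```

With host 127.0.0.1, port 5432, user and database postgres, and password "password", this reproduces the connString shown above.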

Now that we have the connection string, we’re set to call withPostgresqlConn. This function takes the string and then a function requiring a backend:

-- Also various constraints on the monad m
withPostgresqlConn :: (IsSqlBackend backend) => ConnectionString -> (backend -> m a) -> m a

The IsSqlBackend constraint forces us to use a type that conforms to Persistent’s guidelines. SqlPersistT is just a synonym for ReaderT SqlBackend, so in general the only thing we’ll do with this backend is pass it to runReaderT. That lets us run any SqlPersistT action:

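import Control.Monad.Logger (LoggingT, runStdoutLoggingT)
import Control.Monad.Reader (runReaderT)
import Database.Persist.Postgresql (ConnectionString, withPostgresqlConn, runMigration, SqlPersistT)

…

-- Queries run on top of LoggingT IO, since withPostgresqlConn needs a logging monad.
runAction :: ConnectionString -> SqlPersistT (LoggingT IO) a -> IO a
runAction connectionString action = runStdoutLoggingT $ withPostgresqlConn connectionString $ \backend ->
  runReaderT action backend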

Note we add in a call to runStdoutLoggingT so that our action can log its results, as Persistent expects. This is necessary whenever we use withPostgresqlConn. Here's how we would run our migration function:

migrateDB :: IO ()
migrateDB = runAction connString (runMigration migrateAll)

This will create the users table, perfectly to spec with our data definition!
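The runAction pattern is the standard way to run a ReaderT: something hands us an environment, and we feed it to runReaderT. Here is a toy version with no database at all, so you can experiment with the shape of the code. It is a sketch that assumes a made-up FakeBackend, withFakeConn and FakePersistT in place of SqlBackend, withPostgresqlConn and SqlPersistT, and it uses only base and the transformers package that ships with GHC.

```haskell
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Reader (ReaderT, ask, runReaderT)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)

-- Hypothetical stand-in for a database connection: a mutable list of rows.
newtype FakeBackend = FakeBackend (IORef [String])

-- Our analogue of SqlPersistT: actions that read the backend from the environment.
type FakePersistT = ReaderT FakeBackend

-- Hypothetical analogue of withPostgresqlConn: acquire a backend, pass it on.
withFakeConn :: String -> (FakeBackend -> IO a) -> IO a
withFakeConn _connString use = do
  ref <- newIORef []
  use (FakeBackend ref)

-- Same shape as runAction: the backend becomes the ReaderT environment.
runFakeAction :: String -> FakePersistT IO a -> IO a
runFakeAction connString action =
  withFakeConn connString $ \backend -> runReaderT action backend

-- Two "queries" written against the environment.
insertRow :: String -> FakePersistT IO ()
insertRow row = do
  FakeBackend ref <- ask
  liftIO (modifyIORef' ref (row :))

countRows :: FakePersistT IO Int
countRows = do
  FakeBackend ref <- ask
  length <$> liftIO (readIORef ref)
```

For example, runFakeAction "host=127.0.0.1" (insertRow "a" >> insertRow "b" >> countRows) returns 2. The queries never mention a connection explicitly; they ask for it, which is exactly how the Persistent query functions get at the backend.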

Queries

Now let’s wrap up by taking a quick look at the kinds of queries we can run. The first thing we could do is insert a new user into our database. For this, Persistent has the insert function. When we insert the user, we’ll get a key for that user as a result. Here’s the type signature for insert specialized to our particular User type:

insert :: (MonadIO m) => User -> SqlPersistT m (Key User)

Then of course we can also do things in reverse. Suppose we have a key for our user and we want to get it out of the database. We’ll want the get function. This might fail if there is no corresponding user in the database, so we need a Maybe:

get :: (MonadIO m) => Key User -> SqlPersistT m (Maybe User)

We can use these functions for any type satisfying the PersistRecordBackend constraint. This is included for free when we use the Template Haskell approach. So you can use these queries on any type that lives in your schema.
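The contract of insert and get can be modelled with a Map and an ID counter. This is a pure sketch, assuming a hypothetical in-memory Store in place of a real database (insertRec and getRec are illustrative names, and unlike Persistent's keys, this Key is not tagged with the record type). It mirrors only the types: insert hands back a fresh key, and get returns a Maybe because the key might not exist.

```haskell
import Data.Int (Int64)
import qualified Control.Monad.Trans.State.Strict as S
import qualified Data.Map.Strict as Map

newtype Key = Key Int64 deriving (Show, Eq, Ord)

-- Hypothetical in-memory store: the next free ID plus the rows themselves.
data Store r = Store
  { nextId :: Int64
  , rows   :: Map.Map Key r
  }

emptyStore :: Store r
emptyStore = Store 1 Map.empty

-- Like insert: store the record and return its freshly assigned key.
insertRec :: r -> S.State (Store r) Key
insertRec r = do
  Store n m <- S.get
  S.put (Store (n + 1) (Map.insert (Key n) r m))
  pure (Key n)

-- Like get: the lookup can fail, so the result is a Maybe.
getRec :: Key -> S.State (Store r) (Maybe r)
getRec k = S.gets (Map.lookup k . rows)
```

Inserting "admin" and then looking up the returned key gives Just "admin", while looking up a key that was never issued gives Nothing.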

But SQL allows us to do much more than query with the key. Suppose we want to get all the users that meet certain criteria. We’ll want to use the selectList function, which replicates the behavior of the SQL SELECT command. It takes a couple different arguments for the different ways to run a selection. The two list types look a little complicated, but we’ll examine them in more detail:

selectList
  :: (MonadIO m, PersistRecordBackend val SqlBackend)
  => [Filter val]
  -> [SelectOpt val]
  -> SqlPersistT m [Entity val]

As before, the PersistRecordBackend constraint is satisfied by any type in our TH schema. So we know our User type fits. So let’s examine the first argument. It provides a list of different filters that will determine which elements we fetch. For instance, suppose we want all users who are younger than 25 and whose occupation is “Teacher”. Remember the lenses I mentioned that get generated? We’ll create two different filters on this by using these lenses.

selectYoungTeachers :: (MonadIO m) => SqlPersistT m [Entity User]
selectYoungTeachers = selectList [UserAge <. 25, UserOccupation ==. "Teacher"] []

We use the UserAge lens and the UserOccupation lens to choose the fields to filter on. We use a "less-than" operator to state that the age must be smaller than 25. Similarly, we use the ==. operator to match on the occupation. Then we provide an empty list of SelectOpts.
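One way to think about a list of filters is as predicates that must all hold, since Persistent combines the filters in the list with AND. The sketch below models that in plain Haskell. It defines its own local <. and ==. operators over ordinary record accessors, as hypothetical stand-ins for Persistent's operators over generated fields, and applies the filters to an in-memory list instead of generating SQL.

```haskell
-- A filter is modelled as a predicate on the record.
newtype Filter r = Filter (r -> Bool)

-- Local stand-ins for Persistent's operators; a "field" is a plain accessor here.
(<.) :: Ord a => (r -> a) -> a -> Filter r
field <. v = Filter (\r -> field r < v)

(==.) :: Eq a => (r -> a) -> a -> Filter r
field ==. v = Filter (\r -> field r == v)

infix 4 <., ==.

data User = User { userAge :: Int, userOccupation :: String } deriving (Show, Eq)

-- Every filter in the list must hold, like a WHERE clause joined with AND.
selectWhere :: [Filter r] -> [r] -> [r]
selectWhere fs = filter (\r -> all (\(Filter p) -> p r) fs)
```

Given a 23-year-old teacher, a 30-year-old teacher and a 20-year-old engineer, selectWhere [userAge <. 25, userOccupation ==. "Teacher"] keeps only the first. The type checker also rejects userAge ==. "Teacher", which is the same kind of protection Persistent's typed fields provide.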

The second list of selection operations provides some other features we might expect in a select statement. First, we can provide an ordering on our returned data. We’ll also use the generated lenses here. For instance, Asc UserEmail will order our list by email. Here’s what an ordered query would look like:

selectYoungTeachers' :: (MonadIO m) => SqlPersistT m [Entity User]
selectYoungTeachers' = selectList [UserAge <. 25, UserOccupation ==. "Teacher"] [Asc UserEmail]

The other types of SelectOpts include limits and offsets. For instance, we can further modify this query to exclude the first 5 users (as ordered by email) and then limit our selection to 100:

selectYoungTeachers' :: (MonadIO m) => SqlPersistT m [Entity User]
selectYoungTeachers' = selectList
  [UserAge <. 25, UserOccupation ==. "Teacher"] [Asc UserEmail, OffsetBy 5, LimitTo 100]
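The select options can be modelled the same way. In this sketch, SelectOpt and applyOpts are hypothetical local definitions (Persistent's own Asc, Desc, OffsetBy and LimitTo take generated fields, not functions). Like SQL's ORDER BY ... LIMIT ... OFFSET, it sorts first, then skips, then truncates, regardless of where each option appears in the list.

```haskell
{-# LANGUAGE GADTs #-}

import Data.List (sortBy)
import Data.Ord (Down (..), comparing)

-- Each ordering option hides the type of the key it sorts on.
data SelectOpt r where
  Asc      :: Ord k => (r -> k) -> SelectOpt r
  Desc     :: Ord k => (r -> k) -> SelectOpt r
  OffsetBy :: Int -> SelectOpt r
  LimitTo  :: Int -> SelectOpt r

-- Sort by every Asc/Desc key (earlier keys take priority),
-- then apply the offset, then the limit.
applyOpts :: [SelectOpt r] -> [r] -> [r]
applyOpts opts = limit . offset . sortBy ordering
  where
    ordering = mconcat [c | Just c <- map toCmp opts]

    toCmp :: SelectOpt a -> Maybe (a -> a -> Ordering)
    toCmp (Asc f)  = Just (comparing f)
    toCmp (Desc f) = Just (comparing (Down . f))
    toCmp _        = Nothing

    offset = drop (sum [n | OffsetBy n <- opts])

    limit xs = case [n | LimitTo n <- opts] of
      []      -> xs
      (n : _) -> take n xs
```

For example, applyOpts [Asc id, OffsetBy 1, LimitTo 2] ["d", "a", "c", "b"] sorts to a, b, c, d, skips a, and returns ["b", "c"]. With no ordering keys, mconcat yields a comparison that treats everything as equal, so the stable sortBy leaves the input order untouched.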

And that’s all there is to making queries that are type-safe and sensible. We know we’re actually filtering on values that make sense for our types. We don’t have to worry about typos ruining our code at runtime.

Conclusion

Persistent gives us some excellent tools for interacting with databases from Haskell. The Template Haskell mechanisms generate a lot of boilerplate code that helps us. For instance, we can migrate our database to create the correct tables for our Haskell types. We also can perform queries that filter results in a type-safe way. All in all, it’s a fantastic experience.

Never tried Haskell? Do other languages frustrate you with runtime errors whenever you try to run SQL queries? You should give Haskell a try! Download our Getting Started Checklist for some tools that will help you!

If you’re a little familiar with Haskell but aren’t sure how to incorporate libraries like Persistent, you should check out our Stack mini-course. It’ll walk you through the basics of making a simple Haskell program using Stack.

James Bowen

Grenade! Dependently Typed Neural Networks

In the last couple weeks we explored one of the most complex topics I’ve presented on this blog. We examined potential runtime failures that can occur when using Tensor Flow. These included mismatched dimensions and missing placeholders. In an ideal world, we would catch these issues at compile time instead. At its current stage, the Haskell Tensor Flow library doesn’t support that. But we demonstrated that it is possible to add a layer to do this by using dependent types.

Now, I’m still very much a novice at dependent types, so the solutions I presented were rather clunky. This week I'll show a better example of this concept from a different library. The Grenade library uses dependent types everywhere. It allows us to build verifiably-valid neural networks with extreme concision. So let’s dive in and see what it’s all about!

Shapes and Layers

The first thing to learn with this library is the two concepts of Shapes and Layers. Shapes are best compared to tensors from Tensor Flow, except that they exist at the type level. In Tensor Flow we could build tensors with arbitrary dimensions. Grenade currently only supports up to three dimensions. So the different shape types either start with D1, D2, or D3, depending on the dimensionality of the shape. Then each of these type constructors takes a set of natural number parameters. So the following are all valid “Shape” types within Grenade:

D1 5
D2 4 12
D3 8 10 2

The first represents a vector with 5 elements. The second represents a matrix with 4 rows and 12 columns. And the third represents an 8x10x2 matrix (or tensor, if you like). The different numbers represent those values at the type level, not the term level. If this seems confusing, here’s a good tutorial that goes into more depth about the basics of dependent types. The most important idea is that something of type D1 5 can only have 5 elements. A vector of 4 or 6 elements will not type-check.
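To get a feel for how a number can live at the type level, here is a tiny length-indexed vector. This is a sketch using GHC.TypeLits, not Grenade's own S type: Vec, mkVec and vecLength are hypothetical names. The length check happens once, in a smart constructor that returns a Maybe; after that, a Vec 4 simply cannot be used where a Vec 5 is expected.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- A list of doubles whose intended length is recorded in its type.
-- The constructor is meant to stay private; mkVec is the way in.
newtype Vec (n :: Nat) = Vec [Double] deriving (Show, Eq)

-- Build a Vec n only if the list really has n elements.
mkVec :: forall n. KnownNat n => [Double] -> Maybe (Vec n)
mkVec xs
  | fromIntegral (length xs) == natVal (Proxy :: Proxy n) = Just (Vec xs)
  | otherwise = Nothing

-- Read the length back from the type, not from the data.
vecLength :: forall n. KnownNat n => Vec n -> Integer
vecLength _ = natVal (Proxy :: Proxy n)
```

In GHCi, mkVec [1, 2, 3, 4, 5] :: Maybe (Vec 5) gives a Just, while the same call with four elements gives Nothing.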

So now that we know about shapes, let’s examine layers. Layers describe relationships between our shapes. They encapsulate the transformations that happen on our data. The following are all valid layer types:

Relu
FullyConnected 10 20
Convolution 1 10 5 5 1 1

The layer Relu describes a layer that takes in data of any kind of shape and outputs the same shape. In between, it applies the relu activation function to the input data. Since it doesn’t change the shape, it doesn’t need any parameters.

A FullyConnected layer represents the canonical layer of a neural network. It has two parameters, one for the number of input neurons and one for the number of output neurons. In this case, the layer will take 10 inputs and produce 20 outputs.

A Convolution layer represents a 2D convolution like we saw with our MNIST network. This particular example has 1 input feature, 10 output features, uses a 5x5 patch size, and a 1x1 patch offset.

Describing a Network

Now that we have a basic grasp on shapes and layers, we can see how they fit together to create a full network. A network type has two type parameters. The second parameter is a list of the shapes that our data takes at any given point throughout the network. The first parameter is a list of the layers representing the transformations on the data. So let’s say we wanted to describe a very simple network. It will take 4 inputs and produce 10 outputs using a fully connected layer. Then it will perform a Relu activation. This network looks like this:

type SimpleNetwork = Network
  '[FullyConnected 4 10, Relu]
  '[ 'D1 4, 'D1 10, 'D1 10]

The apostrophes in front of the lists and D1 terms indicate that these are promoted constructors. So they are types instead of terms. To “read” this type, we start with the first data format. We go to each successive data format by applying the transformation layer. So for instance we start with a 4-vector, and transform it into a 10-vector with a fully-connected layer. Then we transform that 10-vector into another 10-vector by applying relu. That’s all there is to it! We could apply another FullyConnected layer onto this that will have 3 outputs like so:

type SimpleNetwork = Network
  '[FullyConnected 4 10, Relu, FullyConnected 10 3]
  '[ 'D1 4, 'D1 10, 'D1 10, 'D1 3]
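The way the shape list chains from one layer to the next can be imitated with a small GADT. This is a toy sketch, not Grenade's representation (Grenade's Network is indexed by separate layer and shape lists): Layer, Net, :>>, relu, fullyConnected and runToy are hypothetical names. Each layer records its input and output sizes, and :>> only accepts a layer whose output size matches the input size of the rest of the network.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}

import GHC.TypeLits (Nat)

newtype Vec (n :: Nat) = Vec [Double] deriving (Show, Eq)

-- A layer maps an i-vector to an o-vector.
newtype Layer (i :: Nat) (o :: Nat) = Layer (Vec i -> Vec o)

-- A network from i inputs to o outputs; the middle size h must line up.
data Net (i :: Nat) (o :: Nat) where
  Done  :: Net n n
  (:>>) :: Layer i h -> Net h o -> Net i o
infixr 5 :>>

-- Relu keeps the shape, so its input and output sizes are the same n.
relu :: Layer n n
relu = Layer (\(Vec xs) -> Vec (map (max 0) xs))

-- Toy fully-connected layer: one weight row per output, no bias.
-- (This sketch does not check the weight matrix against i and o.)
fullyConnected :: [[Double]] -> Layer i o
fullyConnected ws = Layer (\(Vec xs) -> Vec [sum (zipWith (*) row xs) | row <- ws])

-- Push a vector through each layer in turn.
runToy :: Net i o -> Vec i -> Vec o
runToy Done v = v
runToy (Layer f :>> rest) v = runToy rest (f v)

-- Mirrors '[FullyConnected 2 3, Relu] over '[D1 2, D1 3, D1 3].
toyNet :: Net 2 3
toyNet = fullyConnected [[1, 0], [0, 1], [-1, -1]] :>> relu :>> Done
```

Running runToy toyNet (Vec [2, 5]) gives Vec [2.0, 5.0, 0.0]. By contrast, (relu :: Layer 3 3) :>> (relu :: Layer 4 4) :>> Done does not compile, because the first layer produces 3 values while the rest of the network expects 4.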

Let's look at MNIST to see a more complicated example. We'll start with a 28x28 image of data. Then we’ll perform the convolution layer I mentioned above. This gives us a 3-dimensional tensor of size 24x24x10. Then we can perform 2x2 max pooling on this, resulting in a 12x12x10 tensor. Finally, we can apply a Relu layer, which keeps it at the same size:

type MNISTStart = Network
  '[Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu]
  '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10]
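The sizes in that shape list follow the usual formula for a window with no padding: out = (in - kernel) / stride + 1. A quick check, written as a hypothetical helper rather than anything Grenade exports, and assuming the window tiles the input evenly:

```haskell
-- Output length when a window of size k moves with stride s over n inputs,
-- with no padding.
windowOut :: Int -> Int -> Int -> Int
windowOut n k s = (n - k) `div` s + 1

-- The first MNIST block: 28 -> 5x5 convolution, stride 1 -> 2x2 pooling, stride 2.
mnistStartSizes :: [Int]
mnistStartSizes = scanl (\n (k, s) -> windowOut n k s) 28 [(5, 1), (2, 2)]
```

mnistStartSizes evaluates to [28, 24, 12], matching the D2 28 28, D3 24 24 10 and D3 12 12 10 shapes. The same formula explains the second convolution in the full network below (12 to 8, then pooling to 4), and flattening 4x4x16 gives the 256 inputs of the first fully connected layer.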

Here’s what a full MNIST example might look like (per the README on the library’s Github page):

type MNIST = Network
    '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu
     , Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu
     , FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
    '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10
     , 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256
     , 'D1 80, 'D1 80, 'D1 10, 'D1 10]

This is a much simpler and more concise description of our network than we can get in Tensor Flow! Let’s examine the ways in which the library uses dependent types to its advantage.

The Magic of Dependent Types

Describing our network as a type seems like a strange idea if you’ve never used dependent types before. But it gives us a couple great perks!

The first major win we get is that it is very easy to generate the starting values of our network. Since it has a specific type, we can let type inference guide us! We don’t need any term level code that is specific to the shape of our network. All we need to do is attach the type signature and call randomNetwork!

randomSimple :: MonadRandom m => m SimpleNetwork
randomSimple = randomNetwork

This will give us all the initial values we need, so we can get going!
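randomNetwork can get away with no arguments because everything it needs is in the type. Here is the same trick in miniature, as a sketch that assumes hypothetical Vec and Weights types. natVal reads the sizes out of the requested type, and we build values of exactly that size. To keep it reproducible, this fills in a constant 0.01 rather than random numbers. Grenade does something analogous through class instances on its layer and shape lists, though its actual mechanism is more involved.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

newtype Vec (n :: Nat) = Vec [Double] deriving (Show, Eq)

-- No arguments: the length comes entirely from the requested type.
initVec :: forall n. KnownNat n => Vec n
initVec = Vec (replicate (fromIntegral (natVal (Proxy :: Proxy n))) 0.01)

-- Weights of an i-to-o fully connected layer: o rows of i columns.
newtype Weights (i :: Nat) (o :: Nat) = Weights [[Double]] deriving Show

initWeights :: forall i o. (KnownNat i, KnownNat o) => Weights i o
initWeights = Weights (replicate rowCount (replicate colCount 0.01))
  where
    colCount = fromIntegral (natVal (Proxy :: Proxy i))
    rowCount = fromIntegral (natVal (Proxy :: Proxy o))
```

Asking for initWeights :: Weights 4 10 produces 10 rows of 4 weights, with the caller only writing the type, just as randomSimple only writes m SimpleNetwork.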

The second (and more important) win is that we can’t build an invalid network! Suppose we try to take our simple network and somehow format it incorrectly. For instance, we could say that instead of the input shape being of size 4, it’s of size 7:

type SimpleNetwork = Network
  '[FullyConnected 4 10, Relu, FullyConnected 10 3]
  '[ 'D1 7, 'D1 10, 'D1 10, 'D1 3]
-- ^^ Notice this 7

This will result in a compile error, since there is a mismatch between the layers. The first layer expects an input of 4, but the first data format is of length 7!

Could not deduce (Layer (FullyConnected 4 10) ('D1 7) ('D1 10))
        arising from a use of ‘randomNetwork’
      from the context: MonadRandom m
        bound by the type signature for:
                   randomSimple :: MonadRandom m => m SimpleNetwork
        at src/IrisGrenade.hs:29:1-48

In other words, it notices that the chain from D1 7 to D1 10 using a FullyConnected 4 10 layer is invalid. So it doesn’t let us make this network. The same thing would happen if we made the layers themselves invalid. For instance, we could make the output and input of the two fully-connected layers not match up:

-- We changed the second to take 20 as the number of input elements.
type SimpleNetwork = Network 
  '[FullyConnected 4 10, Relu, FullyConnected 20 3]
  '[ 'D1 4, 'D1 10, 'D1 20, 'D1 3]

…

/Users/jamesbowen/HTensor/src/IrisGrenade.hs:30:16: error:
    • Could not deduce (Layer (FullyConnected 20 3) ('D1 10) ('D1 3))
        arising from a use of ‘randomNetwork’
      from the context: MonadRandom m
        bound by the type signature for:
                   randomSimple :: MonadRandom m => m SimpleNetwork
        at src/IrisGrenade.hs:29:1-48

So Grenade makes our program much safer by providing compile time guarantees about our network's validity. Runtime errors due to mismatched layer dimensions are impossible!

Training the Network on Iris

Now let’s do a quick run-through of how we actually train this neural network. Readers with a keen eye may have noticed that the SimpleNetwork we’ve built is the same network we used to train the Iris data set. So we’ll do a training run there, using the following steps:

  1. Write the network type and generate a random network from it
  2. Read our input data into a format that Grenade uses
  3. Write a function to run a training iteration
  4. Run it!

1. Write the Network type and Generate Network

So we've already done this first step for the most part. We’ll adjust the names a little bit though. Note that I’ll include the imports list as an appendix to the post. Also, the code is on the grenade branch of my Haskell Tensor Flow repository in IrisGrenade.hs!

type IrisNetwork = Network 
  '[FullyConnected 4 10, Relu, FullyConnected 10 3]
  '[ 'D1 4, 'D1 10, 'D1 10, 'D1 3]

randomIris :: MonadRandom m => m IrisNetwork
randomIris = randomNetwork

runIris :: FilePath -> FilePath -> IO ()
runIris trainingFile testingFile = do
  initialNetwork <- randomIris
  ...

2. Take in our Input Data

We’ll make use of the readIrisFromFile function we used back when we first did Iris. Then we'll make a dependent type called IrisRow, which uses the S type. This S type is a container for a shape. We want our input data to use D1 4 for the 4 input features. Then our output data should use D1 3 for the three possible categories.

-- Dependent type on the dimensions of the row
type IrisRow = (S ('D1 4), S ('D1 3))

If we have malformed data, the types will not match up, so we’ll need to return a Maybe to ensure this succeeds. Note that we normalize the data by dividing by 8. This puts all the data between 0 and 1 and makes for better training results. Here's how we parse the data:

parseRecord :: IrisRecord -> Maybe IrisRow
parseRecord record = case (input, output) of
  (Just i, Just o) -> Just (i, o)
  _ -> Nothing
  where
    input = fromStorable $ VS.fromList $ float2Double <$>
      [ field1 record / 8.0, field2 record / 8.0, field3 record / 8.0, field4 record / 8.0]
    output = oneHot (fromIntegral $ label record)
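If the two sources of Maybe in parseRecord feel opaque, here is a plain-list model of them. It is a sketch with hypothetical names (normalize, oneHotList, parseRow), not Grenade's API. oneHotList returns Nothing for a label outside 0..n-1, and parseRow rejects a row without exactly four features. Both are the kind of malformed input that makes the real parseRecord return Nothing.

```haskell
-- Scale raw measurements towards [0, 1], assuming no feature exceeds 8.
normalize :: [Float] -> [Double]
normalize = map (\x -> realToFrac (x / 8.0))

-- An n-element one-hot vector, or Nothing if the label is out of range.
oneHotList :: Int -> Int -> Maybe [Double]
oneHotList n label
  | label < 0 || label >= n = Nothing
  | otherwise = Just [if i == label then 1 else 0 | i <- [0 .. n - 1]]

-- A row is valid only with exactly 4 features and a label in 0..2.
parseRow :: [Float] -> Int -> Maybe ([Double], [Double])
parseRow features label
  | length features /= 4 = Nothing
  | otherwise = (,) (normalize features) <$> oneHotList 3 label
```

For instance, parseRow [4, 2, 8, 0] 1 gives Just ([0.5, 0.25, 1.0, 0.0], [0.0, 1.0, 0.0]), while a three-feature row or the label 3 gives Nothing.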

Then we incorporate these into our main function:

runIris :: FilePath -> FilePath -> IO ()
runIris trainingFile testingFile = do
  initialNetwork <- randomIris
  trainingRecords <- readIrisFromFile trainingFile
  testRecords <- readIrisFromFile testingFile

  let trainingData = mapMaybe parseRecord (V.toList trainingRecords)
  let testData = mapMaybe parseRecord (V.toList testRecords)

  -- Catch if any were parsed as Nothing
  if length trainingData /= length trainingRecords || length testData /= length testRecords
    then putStrLn "Hmmm there were some problems parsing the data"
    else …

3. Write a Function to Train the Input Data

This is a multi-step process. First we’ll establish our learning parameters. We'll also write a function that will allow us to call the train function on a particular row element:

learningParams :: LearningParameters
learningParams = LearningParameters 0.01 0.9 0.0005

-- Train the network!
trainRow :: LearningParameters -> IrisNetwork -> IrisRow -> IrisNetwork
trainRow lp network (input, output) = train lp network input output

Next we’ll write two more helper functions that will help us test our results. The first will take the network and a test row. It will transform it into the predicted output and the actual output of the network. The second function will take these outputs and reverse the oneHot process to get the label out (0, 1, or 2).

-- Takes a test row, returns predicted output and actual output from the network.
testRow :: IrisNetwork -> IrisRow -> (S ('D1 3), S ('D1 3))
testRow net (rowInput, predictedOutput) = (predictedOutput, runNet net rowInput)

-- Goes from probability output vector to label
getLabels :: (S ('D1 3), S ('D1 3)) -> (Int, Int)
getLabels (S1D predictedLabel, S1D actualOutput) = 
  (maxIndex (extract predictedLabel), maxIndex (extract actualOutput))
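maxIndex (from hmatrix) acts as an argmax here: it returns the position of the largest entry. A list-based sketch of the same idea, using a hypothetical argMax helper, shows how a probability vector turns back into a label. This sketch makes no particular promise about which index wins a tie.

```haskell
import Data.List (maximumBy)
import Data.Ord (comparing)

-- Index of the largest element, or Nothing for an empty list.
argMax :: [Double] -> Maybe Int
argMax [] = Nothing
argMax xs = Just (fst (maximumBy (comparing snd) (zip [0 ..] xs)))
```

argMax [0.1, 0.7, 0.2] is Just 1, so that output vector would be read as label 1.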

Finally we’ll write a function that will take our training data, test data, the network, and an iteration number. It will return the newly trained network, and log some results about how we’re doing. We’ll first take only a sample of our training data and lower the learning rate a little further on each iteration. Then we'll train the network by folding over the sampled data.

run :: [IrisRow] -> [IrisRow] -> IrisNetwork -> Int -> IO IrisNetwork
run trainData testData network iterationNum = do
  sampledRecords <- V.toList <$> chooseRandomRecords (V.fromList trainData)
  -- Slowly drop the learning rate
  let revisedParams = learningParams 
        { learningRate = learningRate learningParams * 0.99 ^ iterationNum}
  let newNetwork = foldl' (trainRow revisedParams) network sampledRecords
  ....
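The decay rule multiplies the base rate by 0.99 once per iteration, so on iteration k the rate is 0.01 * 0.99^k. Over 100 iterations that ends near 0.01 * 0.366, roughly 0.0037, or about a third of the starting value. A hypothetical helper makes the arithmetic easy to check:

```haskell
-- Learning rate used on iteration k, given the base rate
-- (same expression as in run: base * 0.99 ^ k).
decayedRate :: Double -> Int -> Double
decayedRate base k = base * 0.99 ^ k
```

decayedRate 0.01 0 is the untouched 0.01, and decayedRate 0.01 100 is approximately 0.00366.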

Then we’ll wrap up the function by looking at our test data, and seeing how much we got right!

run :: [IrisRow] -> [IrisRow] -> IrisNetwork -> Int -> IO IrisNetwork
run trainData testData network iterationNum = do
  sampledRecords <- V.toList <$> chooseRandomRecords (V.fromList trainData)
  -- Slowly drop the learning rate
  let revisedParams = learningParams
        { learningRate = learningRate learningParams * 0.99 ^ iterationNum}
  let newNetwork = foldl' (trainRow revisedParams) network sampledRecords
  let labelVectors = fmap (testRow newNetwork) testData
  let labelValues = fmap getLabels labelVectors
  let total = length labelValues
  let correctEntries = length $ filter ((==) <$> fst <*> snd) labelValues
  putStrLn $ "Iteration: " ++ show iterationNum
  putStrLn $ show correctEntries ++ " correct out of: " ++ show total
  return newNetwork
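The expression (==) <$> fst <*> snd uses the Applicative instance for functions: it feeds the same pair to both fst and snd and compares the results, so it behaves exactly like \(a, b) -> a == b. Here is a standalone version of the accuracy count, with countCorrect and accuracy as hypothetical helper names:

```haskell
-- Each pair holds the two labels being compared for one test row.
-- (==) <$> fst <*> snd is the same function as \(a, b) -> a == b.
countCorrect :: [(Int, Int)] -> Int
countCorrect = length . filter ((==) <$> fst <*> snd)

-- Fraction of matching pairs; 0 for an empty test set.
accuracy :: [(Int, Int)] -> Double
accuracy [] = 0
accuracy pairs = fromIntegral (countCorrect pairs) / fromIntegral (length pairs)
```

With the pairs [(0, 0), (1, 2), (2, 2), (1, 0)], two labels match, so countCorrect returns 2 and accuracy returns 0.5.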

4. Run it!

We’ll call this now from our main function, iterating 100 times, and we’re done!

runIris :: FilePath -> FilePath -> IO ()
runIris trainingFile testingFile = do
 ...
  if length trainingData /= length trainingRecords || length testData /= length testRecords
    then putStrLn "Hmmm there were some problems parsing the data"
    else foldM_ (run trainingData testData) initialNetwork [1..100]

Comparing to Tensor Flow

So now that we’ve looked at a different library, we can consider how it stacks up against Tensor Flow. So first, the advantages. Grenade's main advantage is that it provides dependent type facilities. This means it is more difficult to write incorrect programs. The basic networks you build are guaranteed to have the correct dimensionality. Additionally, it does not use a “placeholders” system, so you can avoid those kinds of errors too. This means you're likely to have fewer runtime bugs using Grenade.

Concision is another major strong point. The training code got a bit involved when translating our data into Grenade's format. But it’s no more complicated than Tensor Flow. When it comes down to the exact definition of the network itself, we do this in only a few lines with Grenade. It’s complicated to understand what those lines mean if you are new to dependent types. But after seeing a few simple examples you should be able to follow the general pattern.

Of course, none of this means that Tensor Flow is without its advantages. As we saw a couple weeks ago, it is not too difficult to add very thorough logging to your Tensor Flow program. The Tensor Board application will then give you excellent visualizations of this data. It is somewhat more difficult to get intermediate log results with Grenade. There is not too much transparency (that I have found at least) into the inner values of the network. The network types are composable though. So it is possible to get intermediate steps of your operation. But if you break your network into different types and stitch them together, you will remove some of the concision of the network.

Tensor Flow also has a much richer ecosystem of machine learning tools to access. Grenade is still limited to a subset of the most common machine learning layers, like convolution and max pooling. Tensor Flow’s API allows approaches like support vector machines and linear models. So Tensor Flow offers you more options.

One question I may explore in a future article would be to compare the performance of the two libraries. My suspicion is that Tensor Flow is faster due to how it gets all its math down to the C-level. But I’m not too familiar yet with HMatrix (which Grenade depends on for its math) and its efficiency. So I could definitely be wrong.

Conclusion

Grenade provides some truly awesome facilities for building a concise neural network. A Grenade program can demonstrate at compile time that the network is well formed. It also allows an incredibly concise way to define what layers your neural network has. It doesn’t have the Google level support that Tensor Flow does. So it lacks many cool features like logging and visualizations. But it is quite a neat library for its scope. One thing I haven’t mentioned is its mechanics for generative adversarial networks. I’d definitely like to try that out soon!

Grenade is a simpler library to incorporate into Stack compared to Tensor Flow. If you want to compare the two, you should check out our Haskell Tensor Flow guide so you can install TF and get started!

If you’ve never written a line of Haskell before, never fear! Download our Getting Started Checklist for some free resources to start your Haskell education!

Appendix: Compiler Extensions and Imports

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GADTs #-}

import           Control.Monad (foldM_)
import           Control.Monad.Random (MonadRandom)
import           Control.Monad.IO.Class (liftIO)
import           Data.Foldable (foldl')
import           Data.Maybe (mapMaybe)
import qualified Data.Vector.Storable as VS
import qualified Data.Vector as V
import           GHC.Float (float2Double)
import           Grenade
import           Grenade.Core.LearningParameters (LearningParameters(..))
import           Grenade.Core.Shape (fromStorable)
import           Grenade.Utils.OneHot (oneHot)
import           Numeric.LinearAlgebra (maxIndex)
import           Numeric.LinearAlgebra.Static (extract)

import           Processing (IrisRecord(..), readIrisFromFile, chooseRandomRecords)

James Bowen

Checking it's all in Place: Placeholders and Dependent Types

Last week we dove into the world of dependent types. We linked tensors with their shapes at the type level. This gave our program some extra type safety and allowed us to avoid certain runtime errors.

This week, we’re going to solve another runtime conundrum: missing placeholders. We’ll add some more dependent type machinery to ensure we've plugged in all the necessary placeholders! But we’ll see this is not as straightforward as shapes.

To follow along with the code in this article, take a look at this branch on my Haskell Tensor Flow Github repository. All the code for this article is in DepShape.hs. As usual, I've listed the necessary compiler extensions and imports at the bottom of this article. If you want to run the code yourself, you'll have to get Haskell and Tensor Flow running first. Take a look at our Haskell Tensor Flow guide for that!

Now to start, let’s remind ourselves what placeholders are in Tensor Flow and how we use them.

Placeholder Review

Placeholders represent tensors that can have different values on different application runs. This is often the case when we’re training on different samples of data. Here’s our very simple example in Python. We’ll create a couple placeholder tensors by giving only their data type and no values. Then when we actually run the session, we’ll provide a value for each of those tensors.

import tensorflow as tf

node1 = tf.placeholder(tf.float32)
node2 = tf.placeholder(tf.float32)
adderNode = tf.add(node1, node2)
sess = tf.Session()
result1 = sess.run(adderNode, {node1: 3, node2: 4.5 })

The weakness here is that there’s nothing forcing us to provide values for those tensors! We could try running our program without them and we’ll get a runtime crash:

...
sess = tf.Session()
result1 = sess.run(adderNode)
print(result1)
…

Terminal Output:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype float
   [[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Unfortunately, the Haskell Tensor Flow library doesn’t actually do any better here. When we want to fill in placeholders, we provide a list of “feeds”. But our program will still compile even if we pass an empty list! We’ll encounter similar runtime errors:

(node1 :: Tensor Value Float) <- placeholder [1]
(node2 :: Tensor Value Float) <- placeholder [1]
let adderNode = node1 `add` node2
let runStep = \node1Feed node2Feed -> runWithFeeds [] adderNode
runStep (encodeTensorData [1] input1) (encodeTensorData [1] input2)
…

Terminal Output:

TensorFlowException TF_INVALID_ARGUMENT "You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [1]\n\t [[Node: Placeholder_1 = Placeholder[dtype=DT_FLOAT, shape=[1], _device=\"/job:localhost/replica:0/task:0/cpu:0\"]()]]"

In the Iris and MNIST examples, we bury the call to runWithFeeds within our neural network API. We only provide a Model object, whose functions take the input data and the label data as arguments. So the Model forces callers to supply data for both placeholders, and anyone using our model never makes a manual runWithFeeds call.

data Model = Model
  { train :: TensorData Float
          -> TensorData Int64
          -> Session ()
  , errorRate :: TensorData Float
              -> TensorData Int64
              -> SummaryTensor
              -> Session (Float, ByteString)
  }

This isn’t a bad solution! But it’s interesting to see how we can push the envelope with dependent types, so let’s try that!

Adding More “Safe” Types

The first step we’ll take is to augment Tensor Flow’s TensorData type. We’ll want it to carry shape information, like SafeTensor and SafeShape. But we’ll also attach a name to each piece of data. This name lets us identify which placeholder tensor the data should be fed into. At the type level, we refer to this name as a Symbol.

data SafeTensorData a (n :: Symbol) (s :: [Nat]) where
  SafeTensorData :: (TensorType a) => TensorData a -> SafeTensorData a n s
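To make the idea of a name that lives only in the type a bit more concrete, here is a minimal sketch that needs nothing beyond GHC's base library. NamedData is a hypothetical stand-in for SafeTensorData (a plain list takes the place of TensorData), and dataName is an illustrative helper, not part of our Tensor Flow code. It uses symbolVal from GHC.TypeLits to turn the Symbol back into an ordinary String, which is the kind of information a KnownSymbol constraint gives us access to.

```haskell
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)

-- Hypothetical stand-in for SafeTensorData: the feed name exists only in the
-- type (the phantom parameter n); the payload is an ordinary list.
newtype NamedData (n :: Symbol) a = NamedData [a]

-- Read the type-level name back as a run-time String.
dataName :: forall n a. KnownSymbol n => NamedData n a -> String
dataName _ = symbolVal (Proxy :: Proxy n)

-- The name "a" is chosen purely through the type signature.
feedA :: NamedData "a" Float
feedA = NamedData [1, 2, 3, 4]
```

With this in scope, dataName feedA evaluates to "a", even though the string never appears in the value itself.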

Next, we’ll need to make some changes to our SafeTensor type. Each SafeTensor gets a new type parameter: a mapping of names (symbols) to shapes (which are still lists of naturals). We'll call this a placeholder list. This way, each tensor carries type-level information about the placeholders it depends on, and each placeholder has a name and a shape.

data SafeTensor v a (s :: [Nat]) (p :: [(Symbol, [Nat])]) where
  SafeTensor :: (TensorType a) => Tensor v a -> SafeTensor v a s p

Now, recall that when we substituted for placeholders, we used a list of feeds. But this list had no information about the names or dimensions of its feeds. Let's create a new type containing the different elements we need for our feeds. It should also carry the correct type information about the placeholder list. The first step is to define the type so that, like the SafeTensor, it is indexed by the list of placeholders it contains.

data FeedList (pl :: [(Symbol, [Nat])]) where

This structure will look like a linked list, like our SafeShape. Thus we’ll start by defining an “empty” constructor:

data FeedList (pl :: [(Symbol, [Nat])]) where
  EmptyFeedList :: FeedList '[]

Now we’ll add a “Cons”-like constructor, using yet another infix operator, :--:. Each “piece” of our linked list will contain two different items. First, the tensor we are substituting for. Next, the data we’ll use for the substitution. We can use type parameters to force their shapes and data types to match. Then we need the resulting placeholder type: we prepend the type-level tuple containing the symbol and shape to the previous list. This completes our definition.

data FeedList (pl :: [(Symbol, [Nat])]) where
  EmptyFeedList :: FeedList '[]
  (:--:) :: (KnownSymbol n)
    => (SafeTensor Value a s p, SafeTensorData a n s) 
    -> FeedList pl
    -> FeedList ( '(n, s) ': pl)

infixr 5 :--:

Note that we force the tensor to be a Value tensor. We can only substitute data for rendered tensors, hence this restriction. Let's add a quick safeRender so we can render our SafeTensor items.

safeRender :: (MonadBuild m) => SafeTensor Build a s pl -> m (SafeTensor Value a s pl)
safeRender (SafeTensor t1) = do
  t2 <- render t1
  return $ SafeTensor t2

Making a Placeholder

Now we can write our safePlaceholder function. We’ll add a KnownSymbol type constraint. Then we’ll take a SafeShape to give ourselves the type information for the shape. The result is a new tensor whose placeholder list has a single entry, mapping the symbol to the shape.

safePlaceholder :: (MonadBuild m, TensorType a, KnownSymbol sym) => 
  SafeShape s -> m (SafeTensor Value a s '[ '(sym, s)])
safePlaceholder shp = do
  pl <- placeholder (toShape shp)
  return $ SafeTensor pl

This looks a little crazy, and it kind of is! But we’ve now created a tensor that stores its own placeholder information at the type level!

Updating Old Code

Now that we’ve done this, we’re also going to have to update some of our older code. The first part of this is pretty straightforward. We’ll need to change safeConstant so that its result carries the new placeholder parameter. A constant doesn't depend on any placeholders, so it gets an empty list.

safeConstant :: (TensorType a, ShapeProduct s ~ n) => 
  Vector n a -> SafeShape s -> SafeTensor Build a s '[]
safeConstant elems shp = SafeTensor (constant (toShape shp) (toList elems))

Our mathematical operations will be a bit trickier, though. Consider adding two arbitrary tensors. They may or may not share placeholder dependencies. What should be the placeholder type for the resulting tensor? The union of the two input tensors' placeholder lists! Luckily for us, we can use Union from the type-list library to represent this concept.

safeAdd :: (TensorType a, a /= Bool, TensorKind v)
  => SafeTensor v a s p1
  -> SafeTensor v a s p2
  -> SafeTensor Build a s (Union p1 p2)
safeAdd (SafeTensor t1) (SafeTensor t2) = SafeTensor (t1 `add` t2)
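The type-list library computes the union for us, but it may help to see how such a type-level computation can be written by hand. The following is an illustrative sketch, not type-list's actual definition: MyUnion, Member, InsertIfAbsent and KnownSymbols are hypothetical names, the lists hold plain Symbols rather than (Symbol, [Nat]) pairs, and the element order it produces may differ from type-list's Union. The KnownSymbols class reflects the resulting type-level list back into a list of Strings so we can inspect it.

```haskell
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE PolyKinds            #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)

-- Is x an element of xs? (Closed family: equations are tried in order.)
type family Member (x :: k) (xs :: [k]) :: Bool where
  Member x '[]       = 'False
  Member x (x ': xs) = 'True
  Member x (y ': xs) = Member x xs

-- Cons x onto ys unless it is already present.
type family InsertIfAbsent (present :: Bool) (x :: k) (ys :: [k]) :: [k] where
  InsertIfAbsent 'True  x ys = ys
  InsertIfAbsent 'False x ys = x ': ys

-- Hypothetical union: add each element of the first list to the second,
-- skipping any element that is already there.
type family MyUnion (xs :: [k]) (ys :: [k]) :: [k] where
  MyUnion '[]       ys = ys
  MyUnion (x ': xs) ys = InsertIfAbsent (Member x (MyUnion xs ys)) x (MyUnion xs ys)

-- Reflect a type-level list of Symbols into a run-time list of Strings.
class KnownSymbols (ss :: [Symbol]) where
  symbolsVal :: Proxy ss -> [String]

instance KnownSymbols '[] where
  symbolsVal _ = []

instance (KnownSymbol s, KnownSymbols ss) => KnownSymbols (s ': ss) where
  symbolsVal _ = symbolVal (Proxy :: Proxy s) : symbolsVal (Proxy :: Proxy ss)
```

For instance, symbolsVal (Proxy :: Proxy (MyUnion '["a", "b"] '["b", "c"])) gives ["a", "b", "c"]: the shared "b" dependency appears only once.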

We’ll make the same update with matrix multiplication:

safeMatMul :: (TensorType a, a /= Bool, a /= Int8, a /= Int16,
               a /= Int64, a /= Word8, a /= ByteString, TensorKind v)
   => SafeTensor v a '[i,n] p1 -> SafeTensor v a '[n,o] p2 -> SafeTensor Build a '[i,o] (Union p1 p2)
safeMatMul (SafeTensor t1) (SafeTensor t2) = SafeTensor (t1 `matMul` t2)

Running with Placeholders

Now we have all the information we need to write our safeRun function. This will take a SafeTensor, and it will also take a FeedList with the same placeholder type. Remember, a FeedList contains a series of SafeTensorData items. They must match up symbol-for-symbol and shape-for-shape with the placeholders within the SafeTensor. Let’s look at the type signature:

safeRun :: (TensorType a, Fetchable (Tensor v a) r) =>
  FeedList pl -> SafeTensor v a s pl -> Session r

The Fetchable constraint enforces that we can actually get the “result” r out of our tensor. For instance, we can "fetch" a vector of floats out of a tensor that uses Float as its underlying value.

Next, we’ll define a tail-recursive helper function to build the vanilla “list of feeds” out of our FeedList. Through pattern matching, we can pick out the tensor to substitute for and the data we’re using. We combine these into a Feed and cons it onto the accumulated list. Note that this reverses the order of the feeds relative to the FeedList.

safeRun = ...
  where
    buildFeedList :: FeedList ss -> [Feed] -> [Feed]
    buildFeedList EmptyFeedList accum = accum
    buildFeedList ((SafeTensor tensor_, SafeTensorData data_) :--: rest) accum = 
      buildFeedList rest ((feed tensor_ data_) : accum)
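Since FeedList and buildFeedList depend on Tensor Flow types, here is a self-contained mock of the same pattern that runs with base alone. MockTensor, MockData, MockFeeds, (:>:), feedName and toPlainFeeds are hypothetical stand-ins for SafeTensor, SafeTensorData, FeedList, (:--:) and buildFeedList; the "tensor" is only a label and the "data" is a list of Floats. It shows two things: each cons adds a '(name, shape) entry to the type index, and the tail-recursive erasure into a plain list reverses the order of the feeds.

```haskell
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownSymbol, Nat, Symbol, symbolVal)

-- Hypothetical stand-ins: a "tensor" is just a label, "data" is a list.
newtype MockTensor = MockTensor String
newtype MockData (n :: Symbol) (s :: [Nat]) = MockData [Float]

-- Same structure as FeedList: each cons prepends '(n, s) to the index.
data MockFeeds (pl :: [(Symbol, [Nat])]) where
  NoFeeds :: MockFeeds '[]
  (:>:)   :: KnownSymbol n
          => (MockTensor, MockData n s)
          -> MockFeeds pl
          -> MockFeeds ('(n, s) ': pl)

infixr 5 :>:

-- Recover the feed name stored in the type.
feedName :: forall n s. KnownSymbol n => MockData n s -> String
feedName _ = symbolVal (Proxy :: Proxy n)

-- Tail-recursive erasure into a plain list, like buildFeedList.
-- Consing onto the accumulator reverses the order of the feeds.
toPlainFeeds :: MockFeeds pl -> [(String, String, [Float])] -> [(String, String, [Float])]
toPlainFeeds NoFeeds acc = acc
toPlainFeeds ((MockTensor label, d@(MockData xs)) :>: rest) acc =
  toPlainFeeds rest ((label, feedName d, xs) : acc)

example :: MockFeeds '[ '("b", '[2, 2]), '("a", '[2, 2])]
example = (MockTensor "b", MockData [5, 6, 7, 8])
      :>: (MockTensor "a", MockData [1, 2, 3, 4])
      :>: NoFeeds
```

Evaluating toPlainFeeds example [] yields the "a" entry before the "b" entry, the reverse of how we wrote them.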

Now all we have to do to finish up is call the normal runWithFeeds function with the list we’ve created!

safeRun :: (TensorType a, Fetchable (Tensor v a) r) =>
  FeedList pl -> SafeTensor v a s pl -> Session r
safeRun feeds (SafeTensor finalTensor) = runWithFeeds (buildFeedList feeds []) finalTensor
  where
  ...

And here’s what it looks like to use this in practice with our simple example. Notice the type signatures do get a little cumbersome. The signatures on the initial placeholder tensors are necessary. Otherwise the compiler wouldn't know what label we're giving them! The signature on result_, which spells out the union of the placeholder lists, is unnecessary. We can remove it if we want and let type inference do its work.

main3 :: IO (VN.Vector Float)
main3 = runSession $ do
  let (shape1 :: SafeShape '[2,2]) = fromJust $ fromShape (Shape [2,2])
  (a :: SafeTensor Value Float '[2,2] '[ '("a", '[2,2])]) <- safePlaceholder shape1
  (b :: SafeTensor Value Float '[2,2] '[ '("b", '[2,2])] ) <- safePlaceholder shape1
  let result = a `safeAdd` b
  (result_ :: SafeTensor Value Float '[2,2] '[ '("b", '[2,2]), '("a", '[2,2])]) <- safeRender result
  let (feedA :: Vector 4 Float) = fromJust $ fromList [1,2,3,4]
  let (feedB :: Vector 4 Float) = fromJust $ fromList [5,6,7,8]
  let fullFeedList = (b, safeEncodeTensorData shape1 feedB) :--:
                     (a, safeEncodeTensorData shape1 feedA) :--:
                     EmptyFeedList
  safeRun fullFeedList result_

{- It runs!
[6.0,8.0,10.0,12.0]
-}

Now suppose we make some mistakes with our types. Here we’ll take out the “A” feed from our feed list:

-- Let’s take out Feed A!
main = …
  let fullFeedList = (b, safeEncodeTensorData shape1 feedB) :--:
                     EmptyFeedList
  safeRun fullFeedList result_

{- Compiler Error!
• Couldn't match type ‘'['("a", '[2, 2])]’ with ‘'[]’
      Expected type: SafeTensor Value Float '[2, 2] '['("b", '[2, 2])]
        Actual type: SafeTensor
                       Value Float '[2, 2] '['("b", '[2, 2]), '("a", '[2, 2])]
-}

Here’s what happens when we try to substitute a vector with the wrong size. It will identify that we have the wrong number of elements!

main = …
  -- Wrong Size!
  let (feedA :: Vector 8 Float) = fromJust $ fromList [1,2,3,4,5,6,7,8]
  let (feedB :: Vector 4 Float) = fromJust $ fromList [5,6,7,8]
  let fullFeedList = (b, safeEncodeTensorData shape1 feedB) :--:
                     (a, safeEncodeTensorData shape1 feedA) :--:
                     EmptyFeedList
  safeRun fullFeedList result_

{- Compiler Error!
Couldn't match type ‘4’ with ‘8’
        arising from a use of ‘safeEncodeTensorData’
-}

Conclusion: Pros and Cons

So let’s take a step back and look at what we’ve constructed here. We’ve managed to provide ourselves with some pretty cool compile-time guarantees. We’ve also added de facto documentation to our code. Anyone familiar with the codebase can tell at a glance what placeholders we need for each tensor. It’s a lot harder now to write incorrect code. There are still error conditions, of course. But if we’re smart, we can write our code to deal with all of these up front. That way we can fail gracefully instead of hitting a random runtime crash somewhere.

But there are drawbacks. Imagine being a Haskell novice and walking into this codebase. You’ll have no real clue what’s going on (I wouldn’t have two months ago). The types become very cumbersome, and writing them out gets tedious, though as I mentioned, type inference can handle a lot of that. But if you don’t track the types, the type union can be finicky about the ordering of your placeholders. We could fix this with another type family, though.

All these factors could present a real drag on development. But then again, tracking down run-time errors can also do this. Tensor Flow’s error messages can still be a little cryptic. This can make it hard to find root causes.

Since I’m still a novice with dependent types, this code was a little messy. Next week we’ll take a look at a more polished library that uses dependent types for neural networks. We’ll see how the Grenade library allows us to specify a learning system in just a few lines of code!

If you’re new to Haskell, I hope none of this dependent type madness scared you! The language is much easier than these last couple posts make it seem! Try it out, and download our Getting Started Checklist. It'll give you some instructions and tools to help you learn!

If you’re an experienced Haskeller and want to try out Tensor Flow, download our Tensor Flow Guide! It will walk you through incorporating the library into a Stack project!

Appendix: Compiler Extensions and Imports

{-# LANGUAGE GADTs                #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE UndecidableInstances #-}

import           Data.ByteString (ByteString)
import           Data.Int (Int64, Int8, Int16)
import           Data.Maybe (fromJust)
import           Data.Proxy (Proxy(..))
import           Data.Type.List (Union)
import qualified Data.Vector as VN
import           Data.Vector.Sized (Vector, toList, fromList)
import           Data.Word (Word8)
import           GHC.TypeLits

import           TensorFlow.Core
import           TensorFlow.Ops (constant, add, matMul, placeholder)
import           TensorFlow.Session (runSession, run)
import           TensorFlow.Tensor (TensorKind)
James Bowen

Deep Learning and Deep Types: Tensor Flow and Dependent Types

In the introduction to this series, one primary point I made was that Haskell is a safe language. There are a lot of errors we will catch at compile time, rather than runtime. Runtime errors can often be catastrophic to a system, so being able to reduce these is paramount. This is especially true when programming an autonomous car or drone. These objects will be out in the real world where they can hurt people if they malfunction.

So let’s take a look back at some of the code we’ve written over the last 3 or 4 weeks. Is it actually any safer? We’ll find the answer is, well, not so much. It's hard to verify certain properties about code. But the facilities for making this code safer do exist in Haskell! In the next couple articles we'll do some serious hacking with dependent types. We'll be able to prove some of these difficult properties of AI programs at compile time!

The next three articles will focus on dependent type programming. This is a difficult topic, so don’t worry if you can’t follow all the code examples completely. The main idea of making our machine learning code safer is what’s important! So without further ado, let’s rewind to the beginning to see where runtime issues can appear.

If you want to play with this code yourself, check out the dependent shapes branch on my GitHub repository! All the code for this article is in DepShape.hs. If you want to get the code to run, though, you'll probably also need to get Haskell Tensor Flow working. Download our Haskell Tensor Flow Guide for instructions on that!

Issues with Python

Python, as an interpreted language, is definitely subject to runtime bugs. As I was first learning Tensor Flow, I ran into a lot of these. The two most common were placeholder failures and dimension mismatches. For instance, let’s think back to one of the first examples. Our code will have a couple of placeholders, and we submit values for those when we run the session:

import tensorflow as tf  # TensorFlow 1.x API

node1 = tf.placeholder(tf.float32)
node2 = tf.placeholder(tf.float32)
adderNode = tf.add(node1, node2)
sess = tf.Session()
result1 = sess.run(adderNode, {node1: 3, node2: 4.5 })

But there’s nothing stopping us from trying to run the session without submitting values. This will result in a runtime crash:

...
sess = tf.Session()
result1 = sess.run(adderNode)
print(result1)
…

Terminal Output:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype float
   [[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Another issue that came up from time to time was dimension mismatches. Certain operations need certain relationships between the dimensions of the tensors. For instance, you can’t add two vectors with different lengths:

import tensorflow as tf  # TensorFlow 1.x API

node1 = tf.constant([3.0, 4.0, 5.0], dtype=tf.float32)
node2 = tf.constant([4.0, 16.0], dtype=tf.float32)
additionNode = tf.add(node1, node2)  # shape inference raises the ValueError here

sess = tf.Session()
result = sess.run(additionNode)
print(result)

…

Terminal Output:

ValueError: Dimensions must be equal, but are 3 and 2 for 'Add' (op: 'Add') with input shapes: [3], [2].

Again, we get a runtime crash. These seem like the kinds of problems we can solve at compile time.

Does Haskell Solve these Issues?

But anyone who takes a close look at the Haskell code I’ve written so far can see that it doesn’t solve these issues! Here’s a review of our basic placeholder example:

runPlaceholder :: Vector Float -> Vector Float -> IO (Vector Float)
runPlaceholder input1 input2 = runSession $ do
  (node1 :: Tensor Value Float) <- placeholder [1]
  (node2 :: Tensor Value Float) <- placeholder [1]
  let adderNode = node1 `add` node2
  let runStep = \node1Feed node2Feed -> runWithFeeds 
        [ feed node1 node1Feed
        , feed node2 node2Feed
        ] 
        adderNode
  runStep (encodeTensorData [1] input1) (encodeTensorData [1] input2)

Notice how the runWithFeeds function takes a list of Feed objects. The code would still compile fine if we supplied the empty list. Then it would face a fate no better than our Python code:

…
let runStep = \node1Feed node2Feed -> runWithFeeds [] adderNode
…

Terminal Output:

TensorFlowException TF_INVALID_ARGUMENT "You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [1]\n\t [[Node: Placeholder_1 = Placeholder[dtype=DT_FLOAT, shape=[1], _device=\"/job:localhost/replica:0/task:0/cpu:0\"]()]]"

For the second example of dimensionality, we can also make this mistake in Haskell. The following code compiles and will crash at runtime:

runSimple :: IO (Vector Float)
runSimple = runSession $ do
  let node1 = constant [3] [3 :: Float, 4, 5]
  let node2 = constant [2] [4 :: Float, 5]
  let additionNode = node1 `add` node2
  run additionNode
…

Terminal Output:
TensorFlowException TF_INVALID_ARGUMENT "Incompatible shapes: [3] vs. [2]\n\t [[Node: Add_2 = Add[T=DT_FLOAT, _device=\"/job:localhost/replica:0/task:0/cpu:0\"](Const_0, Const_1)]]"

At an even more basic level, we don’t even have to tell the truth about the shape of our vectors! We can give a bogus shape value and it will still compile!

let node1 = constant [3, 2, 3] [3 :: Float, 4, 5]
…

Terminal Output:
invalid tensor length: expected 18 got 3
CallStack (from HasCallStack):
  error, called at src/TensorFlow/Ops.hs:299:23 in tensorflow-ops-0.1.0.0-EWsy8DQdciaL8o6yb2fUKR:TensorFlow.Ops

Can we do better?

Now, we did do some things right. Let's think back to our Model type when we made neural networks.

data Model = Model
  { train :: TensorData Float
          -> TensorData Int64
          -> Session ()
  , errorRate :: TensorData Float
              -> TensorData Int64
              -> SummaryTensor
              -> Session (Float, ByteString)
  }

We exposed our training step as a function. This function forced the user to supply both of the tensors for the placeholders. This is good, but doesn't protect us from dimension issues.

When trying to solve these, we could write wrappers around every operation. Functions like add and matMul could return Maybe values. But this would be clunky. We could take this same step in Python. Granted, monads would allow the Haskell version to compose better. But it would be nicer if we could check our errors all at once, up front.
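For comparison, here is roughly what the Maybe-returning approach could look like. This is a minimal sketch that doesn't use the Tensor Flow API at all: CheckedTensor, mkChecked, checkedAdd and pipeline are hypothetical names, and a tensor is modeled as a shape plus a flat list of values. The checks work, but notice that every step returns a Maybe, and a mismatch only shows up when the code actually runs.

```haskell
-- Hypothetical run-time representation: a shape plus a flat list of values.
data CheckedTensor = CheckedTensor
  { ctShape  :: [Int]
  , ctValues :: [Float]
  } deriving (Eq, Show)

-- Checked constructor: the number of values must equal the product of the dimensions.
mkChecked :: [Int] -> [Float] -> Maybe CheckedTensor
mkChecked s xs
  | product s == length xs = Just (CheckedTensor s xs)
  | otherwise              = Nothing

-- Checked element-wise addition: fails unless both shapes match exactly.
checkedAdd :: CheckedTensor -> CheckedTensor -> Maybe CheckedTensor
checkedAdd (CheckedTensor s1 xs) (CheckedTensor s2 ys)
  | s1 == s2  = Just (CheckedTensor s1 (zipWith (+) xs ys))
  | otherwise = Nothing

-- The Maybe monad threads failures through, but errors are still
-- discovered only when this code runs, not when it compiles.
pipeline :: Maybe CheckedTensor
pipeline = do
  a <- mkChecked [3] [3, 4, 5]
  b <- mkChecked [2] [4, 5]
  checkedAdd a b
```

Here pipeline evaluates to Nothing, and mkChecked [3, 2, 3] [3, 4, 5] would also return Nothing, mirroring the bogus-shape example above.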

If we’re willing to dig quite a bit deeper, we can solve these problems! In the rest of this post, we’ll explore using dependent types to ensure dimensions are always correct. Getting placeholders right turns out to be a little more complicated though! So we’ll save that for next week’s post.

Checking Dimensions

Currently, the Tensor Types we’ve been dealing with have no type safety on the dimensions. Tensor Flow doesn't provide this information when interacting with the C library. So it’s impossible to enforce it at a low level. But this doesn’t stop us from writing wrappers that allow us to solve this.

To write these wrappers, we’re going to need to dive into dependent types. I’ll give a high-level overview of what’s going on. But for some details on the basics, you should check out this tutorial. I’ll also give a shout-out to Renzo Carbonara, author of the Exinst library and other great Haskell things. He helped me a lot in crossing a couple of big knowledge gaps for implementing dependent types.

Intro to Dependent Types: Sized Vectors

The simplest example for introducing dependent types is the idea of sized vectors. If you read the tutorial above, you'll see how they're implemented from scratch. A normal vector has a single type parameter, referring to what type of item the vector contains. A sized vector has an extra type parameter, and this type refers to the size of the vector. For instance, the following are valid sized vector types:

import Data.Vector.Sized (Vector, fromList)

vectorWith2 :: Vector 2 Int64
...
vectorWith6 :: Vector 6 Float
...

In the first type signature, 2 does not refer to the term 2. It refers to the type 2: a type-level natural number of kind Nat. That is, we’ve promoted the number to the type level, where the compiler can check it. The mechanics of how this works are confusing, but here’s the result. We can try to convert normal vectors to sized vectors. But the operation will fail if we don’t match up the size.

import Data.Int (Int64)
import Data.Vector.Sized (Vector, fromList)
import GHC.TypeLits (KnownNat)

-- fromList :: (KnownNat n) => [a] -> Maybe (Vector n a)

-- This results in a “Just” value!
success :: Maybe (Vector 2 Int64)
success = fromList [5,6]

-- The sizes don’t match, so we’ll get “Nothing”!
failure :: Maybe (Vector 2 Int64)
failure = fromList [3,1,5]

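The KnownNat constraint says that n is a type-level natural whose value is known at compile time, so fromList can recover that number at run time and compare it against the length of the input list. So now we can assign a type signature that encapsulates the size of the list.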
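To see how the KnownNat constraint gets used, here is a minimal from-scratch sized vector built only on base. SizedVec, fromListChecked and sizedLength are hypothetical names; this is a sketch of the idea, not the implementation in Data.Vector.Sized. natVal reads the type-level length back as an Integer, which is what lets the checked constructor compare it with the list's actual length.

```haskell
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Int (Int64)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical sized vector: the length n exists only in the type.
newtype SizedVec (n :: Nat) a = SizedVec [a] deriving (Eq, Show)

-- KnownNat n lets us read the type-level length back as an Integer
-- and check it against the real length of the list.
fromListChecked :: forall n a. KnownNat n => [a] -> Maybe (SizedVec n a)
fromListChecked xs
  | fromIntegral (length xs) == natVal (Proxy :: Proxy n) = Just (SizedVec xs)
  | otherwise = Nothing

-- The length comes straight from the type; no traversal is needed.
sizedLength :: forall n a. KnownNat n => SizedVec n a -> Integer
sizedLength _ = natVal (Proxy :: Proxy n)

success :: Maybe (SizedVec 2 Int64)
success = fromListChecked [5, 6]

failure :: Maybe (SizedVec 2 Int64)
failure = fromListChecked [3, 1, 5]
```

In a real module we would not export the SizedVec constructor, so the checked function would be the only way to build one.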

A “Safe” Shape type

Now that we have a very basic understanding of dependent types, let's come up with a game plan for Tensor Flow. The first step will be to make a new type that puts the shape into the type signature. We'll make a SafeShape type that mimics the sized vector type. Instead of storing a single number as the type, it will store the full list of dimensions. We want to create an API something like this:

-- fromShape :: Shape -> Maybe (SafeShape s)

-- Results in a “Just” value
goodShape :: Maybe (SafeShape '[2, 2])
goodShape = fromShape (Shape [2,2])

-- Results in Nothing
badShape :: Maybe (SafeShape '[2,2])
badShape = fromShape (Shape [3,3,2])

So to do this, we first define the SafeShape type. This follows the example of sized vectors. See the appendix below for compiler extensions and imports used throughout this article. In particular, you want GADTs and DataKinds.

data SafeShape (s :: [Nat]) where
  NilShape :: SafeShape '[]
  (:--) :: KnownNat m => Proxy m -> SafeShape s -> SafeShape (m ': s)

infixr 5 :--

Now we can define the toShape function. This will take our SafeShape and turn it into a normal Shape using proxies.

toShape :: SafeShape s -> Shape
toShape NilShape = Shape []
toShape ((pm :: Proxy m) :-- s) = Shape (fromInteger (natVal pm) : s')
  where
    (Shape s') = toShape s

Now for the reverse direction, we first have to make a class MkSafeShape. This class encapsulates all the types that we can turn into the SafeShape type. We’ll define instances of this class for all lists of naturals.

class MkSafeShape (s :: [Nat]) where
  mkSafeShape :: SafeShape s
instance MkSafeShape '[] where
  mkSafeShape = NilShape
instance (MkSafeShape s, KnownNat m) => MkSafeShape (m ': s) where
  mkSafeShape = Proxy :-- mkSafeShape

Now we can define our fromShape function using the MkSafeShape class. To check if it works, we’ll compare the resulting shape to the input shape and make sure they’re equal. Note this requires us to define a simple instance of Eq Shape.

instance Eq Shape where
  (==) (Shape s) (Shape r) = s == r

fromShape :: forall s. MkSafeShape s => Shape -> Maybe (SafeShape s)
fromShape shape = if toShape myShape == shape
  then Just myShape
  else Nothing
  where
    myShape = mkSafeShape :: SafeShape s
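If you'd like to experiment with this type-level machinery without installing Tensor Flow, here is a self-contained version of the SafeShape code above. The main assumption is the Shape type: it is a hypothetical newtype over [Int64] that derives Eq, standing in for Tensor Flow's Shape (so no separate Eq instance is needed). We also write toShape with a local helper instead of the where-clause pattern, a small stylistic change that doesn't affect the behavior.

```haskell
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}

import Data.Int (Int64)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical stand-in for Tensor Flow's Shape, with a derived Eq.
newtype Shape = Shape [Int64] deriving (Eq, Show)

-- A list of dimensions mirrored at the type level.
data SafeShape (s :: [Nat]) where
  NilShape :: SafeShape '[]
  (:--)    :: KnownNat m => Proxy m -> SafeShape s -> SafeShape (m ': s)

infixr 5 :--

-- Erase the type-level dimensions into an ordinary Shape.
toShape :: SafeShape s -> Shape
toShape = Shape . dims
  where
    dims :: SafeShape t -> [Int64]
    dims NilShape      = []
    dims (pm :-- rest) = fromInteger (natVal pm) : dims rest

-- Build the unique SafeShape for any type-level list of naturals.
class MkSafeShape (s :: [Nat]) where
  mkSafeShape :: SafeShape s

instance MkSafeShape '[] where
  mkSafeShape = NilShape

instance (KnownNat m, MkSafeShape s) => MkSafeShape (m ': s) where
  mkSafeShape = Proxy :-- mkSafeShape

-- Succeeds only if the run-time Shape matches the requested type.
fromShape :: forall s. MkSafeShape s => Shape -> Maybe (SafeShape s)
fromShape shape
  | toShape candidate == shape = Just candidate
  | otherwise                  = Nothing
  where
    candidate = mkSafeShape :: SafeShape s
```

As in the article's API, fromShape (Shape [2, 2]) at type Maybe (SafeShape '[2, 2]) gives a Just, while fromShape (Shape [3, 3, 2]) at that type gives Nothing.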

Now that we’ve done this for Shape, we can create a similar type for Tensor that will store the shape as a type parameter.

data SafeTensor v a (s :: [Nat]) where
  SafeTensor :: (TensorType a) => Tensor v a -> SafeTensor v a s

Using our Safe Types

So what has all this gotten us? Our next goal is to create a safeConstant function. This will let us create a SafeTensor wrapping a constant tensor and storing the shape. Remember, constant takes a shape and a list of values without ensuring any correspondence between them. We want something like this:

safeConstant :: (TensorType a) => Vector n a -> SafeShape s -> SafeTensor Build a s
safeConstant elems shp = SafeTensor $ constant (toShape shp) (toList elems)

This will attach the given shape to the tensor. But there’s one piece missing. We also want to create a connection between the number of input elements and the shape. So something with shape [3,3,2] should force you to input a vector of length 18. And right now, there is no constraint between n and s.

We’ll add this with a type family called ShapeProduct. Its instances state that the number of elements for a given list of dimensions is the product of those dimensions. We define the second instance recursively, so we'll need UndecidableInstances.

type family ShapeProduct (s :: [Nat]) :: Nat
type instance ShapeProduct '[] = 1
type instance ShapeProduct (m ': s) = m * ShapeProduct s
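We can check that ShapeProduct computes what we expect by reflecting its result back to a value. The sketch below is standalone and assumes GHC 8.6 or later, where the NoStarIsType extension makes * refer to type-level multiplication rather than the kind of types. elementCount is a hypothetical helper, not part of the article's code.

```haskell
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE NoStarIsType         #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits

-- Same definition as above: the element count is the product of the dimensions.
type family ShapeProduct (s :: [Nat]) :: Nat
type instance ShapeProduct '[]      = 1
type instance ShapeProduct (m ': s) = m * ShapeProduct s

-- Reflect the type-level product back to a run-time Integer.
elementCount :: forall s. KnownNat (ShapeProduct s) => Proxy s -> Integer
elementCount _ = natVal (Proxy :: Proxy (ShapeProduct s))
```

So elementCount (Proxy :: Proxy '[3, 3, 2]) is 18, matching the length our [3,3,2] constant should require.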

Now we’re almost done with this part! We can fix our safeConstant function by adding the constraint ShapeProduct s ~ n, which ties the vector length n to the shape s.

safeConstant :: (TensorType a, ShapeProduct s ~ n) => Vector n a -> SafeShape s -> SafeTensor Build a s
safeConstant elems shp = SafeTensor $ constant (toShape shp) (toList elems)

Now we can write out a simple use of our safeConstant function as follows:

main :: IO (VN.Vector Int64)
main = runSession $ do
  let (shape1 :: SafeShape '[2,2]) = fromJust $ fromShape (Shape [2,2])
  let (elems1 :: Vector 4 Int64) = fromJust $ fromList [1,2,3,4]
  let (constant1 :: SafeTensor Build Int64 '[2,2]) = safeConstant elems1 shape1
  let (SafeTensor t) = constant1
  run t

We’re using fromJust as a shortcut here. But in a real program you would read your initial tensors in and check them as Maybe values. There's still the possibility for runtime failures. But this system has a couple advantages. First, it won't crash. We'll have the opportunity to handle it gracefully. Second, we do all the error checking up front. Once we've assigned types to everything, all the failure cases should be covered.

Going back to the last example, let's change something. For instance, we could make our vector have length 3 instead of 4. We’ll now get a compile error!

main :: IO (VN.Vector Int64)
main = runSession $ do
  let (shape1 :: SafeShape '[2,2]) = fromJust $ fromShape (Shape [2,2])
  let (elems1 :: Vector 3 Int64) = fromJust $ fromList [1,2,3]
  let (constant1 :: SafeTensor Build Int64 '[2,2]) = safeConstant elems1 shape1
  let (SafeTensor t) = constant1
  run t

…

    • Couldn't match type ‘4’ with ‘3’
        arising from a use of ‘safeConstant’
    • In the expression: safeConstant elems1 shape1
      In a pattern binding:
        (constant1 :: SafeTensor Build Int64 '[2, 2])
          = safeConstant elems1 shape1

Adding Type Safe Operations

Now that we’ve attached shape information to our tensors, we can define safer math operations. It's easy to write a safe addition function that ensures that the tensors have the same shape:

safeAdd :: (TensorType a, a /= Bool) => SafeTensor Build a s -> SafeTensor Build a s -> SafeTensor Build a s
safeAdd (SafeTensor t1) (SafeTensor t2) = SafeTensor (t1 `add` t2)

Here’s a similar matrix multiplication function. It ensures we have 2-dimensional shapes and that the dimensions work out. Notice the two tensors share the n dimension. It must be the column dimension of the first tensor and the row dimension of the second tensor:

safeMatMul :: (TensorType a, a /= Bool, a /= Int8, a /= Int16, a /= Int64, a /= Word8, a /= ByteString)
   => SafeTensor Build a '[i,n] -> SafeTensor Build a '[n,o] -> SafeTensor Build a '[i,o]
safeMatMul (SafeTensor t1) (SafeTensor t2) = SafeTensor (t1 `matMul` t2)
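To try the same shape discipline without Tensor Flow, here's a small list-backed mock. Mat, mkMat, matAdd, matMul and example are hypothetical names, and a matrix is just a list of rows of Doubles with its shape as a phantom type-level list. matAdd and matMul carry the same shape constraints as safeAdd and safeMatMul, while mkMat plays the role of the run-time check at the boundary.

```haskell
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Data.List (transpose)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical stand-in for SafeTensor: the shape is a phantom
-- type-level list and the payload is a list of rows.
newtype Mat (s :: [Nat]) = Mat [[Double]] deriving (Eq, Show)

-- Checked 2-D constructor: row and column counts must match i and o.
mkMat :: forall i o. (KnownNat i, KnownNat o) => [[Double]] -> Maybe (Mat '[i, o])
mkMat rows
  | length rows == rowCount && all ((== colCount) . length) rows = Just (Mat rows)
  | otherwise = Nothing
  where
    rowCount = fromInteger (natVal (Proxy :: Proxy i))
    colCount = fromInteger (natVal (Proxy :: Proxy o))

-- Same-shape addition, mirroring safeAdd's signature.
matAdd :: Mat s -> Mat s -> Mat s
matAdd (Mat a) (Mat b) = Mat (zipWith (zipWith (+)) a b)

-- The shared n forces the inner dimensions to agree, as in safeMatMul.
matMul :: Mat '[i, n] -> Mat '[n, o] -> Mat '[i, o]
matMul (Mat a) (Mat b) = Mat [[sum (zipWith (*) row col) | col <- transpose b] | row <- a]

-- Shape checks happen once, up front; the arithmetic below is type-checked.
example :: Maybe (Mat '[2, 2])
example = do
  a <- mkMat [[1, 2, 3], [4, 5, 6]] :: Maybe (Mat '[2, 3])
  b <- mkMat [[1, 0], [0, 1], [1, 1]] :: Maybe (Mat '[3, 2])
  c <- mkMat [[10, 10], [10, 10]]
  pure (matMul a b `matAdd` c)
```

Here matMul a a would be rejected at compile time: a has shape '[2, 3], so its inner dimension 3 doesn't match the 2 rows of the second argument.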

Here are these functions in action:

main2 :: IO (VN.Vector Float)
main2 = runSession $ do
  let (shape1 :: SafeShape '[4,3]) = fromJust $ fromShape (Shape [4,3])
  let (shape2 :: SafeShape '[3,2]) = fromJust $ fromShape (Shape [3,2])
  let (shape3 :: SafeShape '[4,2]) = fromJust $ fromShape (Shape [4,2])
  let (elems1 :: Vector 12 Float) = fromJust $ fromList [1,2,3,4,1,2,3,4,1,2,3,4]
  let (elems2 :: Vector 6 Float) = fromJust $ fromList [5,6,7,8,9,10]
  let (elems3 :: Vector 8 Float) = fromJust $ fromList [11,12,13,14,15,16,17,18]
  let (constant1 :: SafeTensor Build Float '[4,3]) = safeConstant elems1 shape1
  let (constant2 :: SafeTensor Build Float '[3,2]) = safeConstant elems2 shape2
  let (constant3 :: SafeTensor Build Float '[4,2]) = safeConstant elems3 shape3
  let (multTensor :: SafeTensor Build Float '[4,2]) = constant1 `safeMatMul` constant2
  let (addTensor :: SafeTensor Build Float '[4,2]) = multTensor `safeAdd` constant3
  let (SafeTensor finalTensor) = addTensor
  run finalTensor

And of course we’ll get compile errors if we use incorrect dimensions anywhere. Let’s say we change multTensor to use [4,3] as its type:

• Couldn't match type ‘2’ with ‘3’
      Expected type: SafeTensor Build Float '[4, 3]
        Actual type: SafeTensor Build Float '[4, 2]
    • In the expression: constant1 `safeMatMul` constant2
…
 • Couldn't match type ‘3’ with ‘2’
      Expected type: SafeTensor Build Float '[4, 2]
        Actual type: SafeTensor Build Float '[4, 3]
    • In the expression: multTensor `safeAdd` constant3
…
• Couldn't match type ‘2’ with ‘3’
      Expected type: SafeTensor Build Float '[4, 3]
        Actual type: SafeTensor Build Float '[4, 2]
    • In the second argument of ‘safeAdd’, namely ‘constant3’

Conclusion

In this exercise we got deep into the weeds of one of the most difficult topics to learn about in Haskell. Dependent types will make your head spin at first. But we saw a concrete example of how they can allow us to detect problematic code at compile time. They are a form of documentation that also enables us to verify that our code is correct in certain ways.

Types do not replace tests (especially behavioral tests). But in this instance there are at least a few different test cases we don’t need to worry about too much. Next week, we’ll see how we can apply these principles to verifying placeholders.

If you want to learn more about the nuts and bolts of using Haskell Tensor Flow, you should check out our Tensor Flow Guide. It will guide you through the basics of adding Tensor Flow to a simple Stack project.

Maybe you’ve never used Haskell before, but I’ve convinced you that dependent types are the future. If you want to try it out, download our Getting Started Checklist. You can also learn how to create and organize Haskell projects using Stack! Check out our Stack mini-course!

Appendix: Extensions and Imports

{-# LANGUAGE GADTs                #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE UndecidableInstances #-}

import           Data.ByteString (ByteString)
import           Data.Constraint (Constraint)
import           Data.Int (Int64, Int8, Int16)
import           Data.Maybe (fromJust)
import           Data.Proxy (Proxy(..))
import qualified Data.Vector as VN
import           Data.Vector.Sized (Vector(..), toList, fromList)
import           Data.Word (Word8)
import           GHC.TypeLits

import           TensorFlow.Core
import           TensorFlow.Ops (constant, add, matMul)
import           TensorFlow.Session (runSession, run)
James Bowen

Deeper Still: Convolutional Neural Networks

Two weeks ago, we began our machine learning study in earnest by constructing a full neural network. But this network was still quite simple by deep learning standards. In this article, we’re going to tackle a much more difficult problem: image recognition. Of course, we’ll still be using a well-known data set with well-known results, so this is only the tip of the iceberg. We'll be using the MNIST data set, which contains images of handwritten digits labeled with the numbers 0-9. This problem is so well-known that the folks at Tensor Flow refer to it as the “Hello World” of machine learning.

We’ll start with a very similar approach to the one we used for the Iris data set: a fully-connected neural network with two layers, trained with the “Adam” optimizer. This will give us some decent results by our beginner standards. But MNIST is a well-known problem with a very large data set, so we’re going to hold ourselves to a higher standard of accuracy this time. This will force us to use some more advanced techniques. But to start with, let’s examine what we need to change to adapt our Iris model to the MNIST problem. As with the last couple of weeks, the code for all this is on Github if you want to follow along.

Re-use and Recycle!

Generally, we can re-use most of the code we had with Iris, which is good news! We still have to make a few adjustments here and there though. First, we’ll use some different constants. We’ll use mnistFeatures in place of irisFeatures, and mnistLabels instead of irisLabels. We’ll also bump up the size of our hidden layer and the number of samples we’ll draw on each iteration:

mnistFeatures :: Int64
mnistFeatures = 784

mnistLabels :: Int64
mnistLabels = 10

numHiddenUnits :: Int64
numHiddenUnits = 1024

sampleSize :: Int
sampleSize = 1000

We’ll also change our model to use Word8 as the result type instead of Int64.

data Model = Model
 { train :: TensorData Float
         -> TensorData Word8 -- Used to be Int64
         -> Session ()
 , errorRate :: TensorData Float
             -> TensorData Word8 -- Used to be Int64
             -> SummaryTensor
             -> Session (Float, ByteString)
 }

Now we have to change how we get our input data. Our data isn’t in CSV format this time. We’ll use helper functions from the Tensor Flow library to extract the images and labels:

import TensorFlow.Examples.MNIST.Parse (readMNISTSamples, readMNISTLabels)
…
runDigits :: FilePath -> FilePath -> FilePath -> FilePath -> IO ()
runDigits trainImageFile trainLabelFile testImageFile testLabelFile = 
 withEventWriter eventsDir $ \eventWriter -> runSession $ do

   -- trainingImages, testImages :: [Vector Word8]
   trainingImages <- liftIO $ readMNISTSamples trainImageFile
   testImages <- liftIO $ readMNISTSamples testImageFile

   -- trainingLabels, testLabels :: [Word8]
   trainingLabels <- liftIO $ readMNISTLabels trainLabelFile
   testLabels <- liftIO $ readMNISTLabels testLabelFile

   -- trainingRecords, testRecords :: Vector (Vector Word8, Word8)
   let trainingRecords = fromList $ zip trainingImages trainingLabels
   let testRecords = fromList $ zip testImages testLabels
   ...

Our “input” type consists of vectors of Word8 elements. These represent the intensity of various pixels. Our “output” type is Word8, referring to the actual labels (0-9). We read the images and labels from separate files. Then we zip them together to pass to our processing functions. We’ll have to make a few changes to these processing functions for this data set. First, we have to generalize the type of our randomization function:

-- Used to be IrisRecord-specific
chooseRandomRecords :: Vector a -> IO (Vector a)

Next we have to write a new encoding function that will put our data into the TensorData format. This looks like our old version, except dealing with the new tuple type instead of the IrisRecord.

convertDigitRecordsToTensorData 
 :: Vector (Vector Word8, Word8)
 -> (TensorData Float, TensorData Word8)
convertDigitRecordsToTensorData records = (input, output)
 where
   numRecords = Data.Vector.length records 
   input = encodeTensorData [fromIntegral numRecords, mnistFeatures]
     (fromList $ concatMap recordToInputs records)
   output = encodeTensorData [fromIntegral numRecords] (snd <$> records)
   recordToInputs :: (Vector Word8, Word8) -> [Float]
   recordToInputs rec = fromIntegral <$> (toList . fst) rec
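To see why the input shape is [numRecords, mnistFeatures], here's a minimal list-based sketch of the same layout, with plain lists standing in for Vector and TensorData (flattenDigitRecords is a hypothetical name, not part of the Tensor Flow library). The pixels of all images are concatenated in order, so row r of the 2D input tensor is image r:

```haskell
import Data.Word (Word8)

-- Hypothetical list-based stand-in for convertDigitRecordsToTensorData.
-- Returns (input shape, flat input values, output shape, labels), or Nothing
-- if some image does not have exactly numFeatures pixels.
flattenDigitRecords
  :: Int                  -- pixels per image (784 = 28 * 28 for MNIST)
  -> [([Word8], Word8)]   -- (pixel intensities, label) pairs
  -> Maybe ([Int], [Float], [Int], [Word8])
flattenDigitRecords numFeatures records
  | all ((== numFeatures) . length . fst) records =
      Just ( [numRecords, numFeatures]
           , concatMap (map fromIntegral . fst) records  -- image 0's pixels, then image 1's, ...
           , [numRecords]
           , map snd records
           )
  | otherwise = Nothing
  where
    numRecords = length records
```

For example, two 3-pixel "images" become a [2, 3] input of six values plus a [2] output of labels. The real function skips the length check and trusts the data files.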

And then we just have to substitute our new functions and parameters in, and we’ll be able to run our digit trainer!

Current training error 89.8
Current training error 19.300001
Current training error 13.300001
Current training error 11.199999
Current training error 8.700001
Current training error 6.5999985
Current training error 6.999999
Current training error 5.199999
Current training error 4.400003
Current training error 5.000001
Current training error 2.3000002

test error 6.830001

So our accuracy is 93.2%. This seems like an alright number. But imagine being a post office and having 6.8% of your mail sorted into the wrong ZIP code! (This was the original use case of this data set.) So let’s see if we can do better.

Convolution and Max Pooling

Now, we could train our model longer, which would tend to improve our error rate. But we can also help ourselves by making our model more complex. The fundamental flaw with what we’ve got so far is that it ignores the 2D nature of the images, so we're losing a ton of useful information. So the first thing we'll do is treat our images as 28x28 tensors instead of 1x784. This way, our model can pick out specific areas that are significant for identifying the digit.
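Treating the 784 values as a 28x28 grid relies on their order: MNIST stores each image row by row, and reshape reads values in the same row-major order. A small sketch of that index arithmetic (imageWidth, toRowCol, fromRowCol, and toRows are hypothetical names):

```haskell
-- Width (and height) of an MNIST image, in pixels.
imageWidth :: Int
imageWidth = 28

-- Row-major layout: pixel (row, col) sits at flat index row * 28 + col.
toRowCol :: Int -> (Int, Int)
toRowCol i = i `divMod` imageWidth

fromRowCol :: (Int, Int) -> Int
fromRowCol (row, col) = row * imageWidth + col

-- Split a flat 784-element image into 28 rows of 28 pixels each.
toRows :: [a] -> [[a]]
toRows [] = []
toRows xs = row : toRows rest
  where (row, rest) = splitAt imageWidth xs
```

So flat index 29 is row 1, column 1, and the last pixel (index 783) is row 27, column 27.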

One thing we want to account for is that the digit might not be in the center of the frame. To handle this, we're going to apply convolution. When using convolution, we break the image into many overlapping tiles. In our case, we’ll use a stride of 1 in every direction and a patch size of 5x5. This means we’ll center a 5x5 tile on every pixel in our image and compute a score for it. That score tells us whether this part of the image contains any important information. We can represent this score as a vector with many features.
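Here's a minimal single-channel sketch of what one output feature of such a layer computes, assuming stride 1 and "SAME" zero padding (the settings we'll use below). conv2DSame is a hypothetical name; a real layer repeats this for every pair of input and output channels, sums over the input channels, and adds a bias:

```haskell
-- Single-channel 2D convolution with stride 1 and "SAME" zero padding.
-- Like TensorFlow's conv2D, the kernel is slid over the image without flipping
-- (strictly speaking, a cross-correlation). Output size equals input size.
conv2DSame :: [[Float]] -> [[Float]] -> [[Float]]
conv2DSame image kernel =
  [ [ sum [ kernel !! ki !! kj * pixelAt (r + ki - padTop) (c + kj - padLeft)
          | ki <- [0 .. kh - 1], kj <- [0 .. kw - 1] ]
    | c <- [0 .. w - 1] ]
  | r <- [0 .. h - 1] ]
  where
    h = length image
    w = case image of { (row : _) -> length row; [] -> 0 }
    kh = length kernel
    kw = case kernel of { (row : _) -> length row; [] -> 0 }
    padTop = (kh - 1) `div` 2    -- 2 for a 5x5 patch
    padLeft = (kw - 1) `div` 2
    -- Positions outside the image read as 0: this is the zero padding.
    pixelAt i j
      | i < 0 || i >= h || j < 0 || j >= w = 0
      | otherwise = image !! i !! j
```

A kernel of all ones turns each pixel into the sum of its neighborhood, while a kernel with a single 1 in the center returns the image unchanged. The learned weights sit somewhere between these extremes, picking out strokes and edges.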

So with 2D convolution, we'll be dealing with 4-dimensional tensors. The first dimension is the number of samples. The second and third dimensions are the height and width of the image. The final dimension is the number of features in the "score" for each part of the image. Each original image starts out with a single feature per pixel: the intensity of that pixel! Then each layer of convolution will act as a mini neural network per pixel, producing as many features as we want.

The different sliding windows correspond to scores we store in the next layer. This example uses 3x3 patches; we'll use 5x5.

Max pooling is a form of down-sampling. After our first convolution step, we’ll have scores for every position of the 28x28 image. We’ll use 2x2 max pooling, meaning we divide each image into 2x2 squares. Then we’ll make a new 14x14 layer that keeps only the largest score from each 2x2 square. This makes our model more efficient while keeping the most important information.
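A single-channel sketch of 2x2 max pooling with stride 2 on plain lists (pairs and maxPool2x2 are hypothetical names):

```haskell
-- Split a list into chunks of (at most) two elements.
pairs :: [a] -> [[a]]
pairs [] = []
pairs xs = chunk : pairs rest
  where (chunk, rest) = splitAt 2 xs

-- 2x2 max pooling with stride 2 on a single channel: each output value is the
-- largest of the (up to) four inputs in its 2x2 square.
maxPool2x2 :: [[Float]] -> [[Float]]
maxPool2x2 image = map (map maximum . pairs . columnMax) (pairs image)
  where
    -- Element-wise maximum of the one or two rows in a chunk.
    columnMax :: [[Float]] -> [Float]
    columnMax = foldr1 (zipWith max)
```

On a 4x4 input such as [[1,2,5,6],[3,4,7,8],[9,1,2,3],[1,1,4,0]] this gives [[4,8],[9,4]]: one value per square, a quarter of the original size.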

Simple demonstration of max-pooling

Implementing Convolutional Layers

We’ll do two rounds of convolution and max pooling, so we’ll write a function that builds a layer performing both steps. This will look a lot like our other neural network layer. It takes the number of input and output channels, the input tensor itself, and a name for the layer. Our first step is to create the weights and bias tensors from these parameters:

patchSize :: Int64
patchSize = 5

buildConvPoolLayer :: Int64 -> Int64 -> Tensor v Float -> Text
                  -> Build (Variable Float, Variable Float, Tensor Build Float)
buildConvPoolLayer inputChannels outputChannels input layerName = withNameScope layerName $ do
 weights <- truncatedNormal (vector weightsShape)
   >>= initializedVariable
 bias <- truncatedNormal (vector [outputChannels]) >>= initializedVariable
 ...
 where
   weightsShape :: [Int64]
   weightsShape = [patchSize, patchSize, inputChannels, outputChannels]

Now we’ll want to call our convolution and max pooling functions. These are still a little rough around the edges (the Haskell library is still quite young). The C versions of these functions have many optional, named attributes. For the moment, there don’t seem to be any functions that take normal Haskell values for these arguments. Instead, we’ll use opAttr setters, each of which assigns a value to a named attribute, and compose them with (.):

where
 ...
 convStridesAttr = opAttr "strides" .~ ([1,1,1,1] :: [Int64])
 poolStridesAttr = opAttr "strides" .~ ([1,2,2,1] :: [Int64])
 poolKSizeAttr = opAttr "ksize" .~ ([1,2,2,1] :: [Int64])
 paddingAttr = opAttr "padding" .~ ("SAME" :: ByteString)
 dataFormatAttr = opAttr "data_format" .~ ("NHWC" :: ByteString)
 convAttrs = convStridesAttr . paddingAttr . dataFormatAttr
 poolAttrs = poolKSizeAttr . poolStridesAttr . paddingAttr . dataFormatAttr

The strides attribute for convolution says how far we shift the window each time, along each of the four NHWC dimensions, so [1,1,1,1] moves one pixel at a time. For pooling, the ksize attribute sets the size of the window (2x2 here), and strides of 2 move the window by its full width, so the windows don't overlap. Now that we have our attributes, we can call the library functions conv2D' and maxPool', which give us our resulting tensor. We also throw in a call to relu between these steps.

buildConvPoolLayer :: Int64 -> Int64 -> Tensor v Float -> Text
                  -> Build (Variable Float, Variable Float, Tensor Build Float)
buildConvPoolLayer inputChannels outputChannels input layerName = withNameScope layerName $ do
 weights <- truncatedNormal (vector weightsShape)
   >>= initializedVariable
 bias <- truncatedNormal (vector [outputChannels]) >>= initializedVariable
 let conv = conv2D' convAttrs input (readValue weights)
       `add` readValue bias
 let results = maxPool' poolAttrs (relu conv)
 return (weights, bias, results)
 where
   ...
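Where do the 14 and the 7 come from? With "SAME" padding, TensorFlow sizes each spatial output dimension as the input size divided by the stride, rounded up. The window size only changes how much padding gets added. A small sketch of that arithmetic (sameOutputSize and mnistSpatialSizes are names made up for illustration):

```haskell
-- Output length of one spatial dimension under "SAME" padding:
-- ceiling (inputSize / stride). The window size doesn't affect it.
sameOutputSize :: Int -> Int -> Int
sameOutputSize inputSize stride = (inputSize + stride - 1) `div` stride

-- Height/width after each step of the network:
-- conv (stride 1), pool (stride 2), conv (stride 1), pool (stride 2).
mnistSpatialSizes :: [Int]
mnistSpatialSizes = scanl sameOutputSize 28 [1, 2, 1, 2]
```

The convolutions keep the size at 28 and then 14, while the two pooling steps take it from 28 to 14 and from 14 to 7. That is why the dense layer's input size below is 7 * 7 * 64.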

Modifying our Model

Now we’ll make a few updates to our model and we’ll be in good shape. First, we need to reshape our input data to be 4-dimensional. Then, we’ll apply the two convolution/pooling layers:

imageDimen :: Int64
imageDimen = 28

createModel :: Build Model
createModel = do
 let batchSize = -1 -- Allows variable sized batches
 let conv1OutputChannels = 32
 let conv2OutputChannels = 64
 let denseInputSize = 7 * 7 * 64 :: Int64 -- 3136
 let numHiddenUnits = 1024

 inputs <- placeholder [batchSize, mnistFeatures]
 outputs <- placeholder [batchSize]

 let inputImage = reshape inputs (vector [batchSize, imageDimen, imageDimen, 1])

 (convWeights1, convBiases1, convResults1) <- 
   buildConvPoolLayer 1 conv1OutputChannels inputImage "convLayer1"
 (convWeights2, convBiases2, convResults2) <-
   buildConvPoolLayer conv1OutputChannels conv2OutputChannels convResults1 "convLayer2"

Once we’re done with that, we’ll apply two fully-connected (dense) layers as we did before. Note we'll reshape our result from four dimensions back down to two:

let denseInput = reshape convResults2 (vector [batchSize, denseInputSize])
(denseWeights1, denseBiases1, denseResults1) <-
  buildNNLayer (fromIntegral denseInputSize) numHiddenUnits denseInput "denseLayer1"  
let rectifiedDenseResults = relu denseResults1
(denseWeights2, denseBiases2, denseResults2) <-
   buildNNLayer numHiddenUnits mnistLabels rectifiedDenseResults "denseLayer2"

And after that we can treat the rest of the model the same. We'll update the parameter names and add the new weights and biases to the params that the model can change.

As a review, let’s look at the dimensions of each of the intermediate tensors. This shows the restrictions each operation places on its arguments. Each convolution step takes two four-dimensional tensors: the input and the weights. The final dimension of the input must match the third dimension of the weights, and the result takes the final dimension of the weights in its place. Meanwhile, 2x2 pooling with a stride of 2 takes this resulting 4-dimensional tensor and halves its two middle (height and width) dimensions.

input: n x 784
inputImage: n x 28 x 28 x 1
convWeights1: 5 x 5 x 1 x 32
convBias1: 32
conv (first layer): n x 28 x 28 x 32
convResults1: n x 14 x 14 x 32
convWeights2: 5 x 5 x 32 x 64
conv (second layer): n x 14 x 14 x 64
convResults2: n x 7 x 7 x 64
denseInput: n x 3136
denseWeights1: 3136 x 1024
denseBias1: 1024
denseResults1: n x 1024
denseWeights2: 1024 x 10
denseBias2: 10
denseResults2: n x 10
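The rules in the paragraph above can be checked mechanically. Here's a small, hypothetical shape calculator (plain tuples stand in for tensor shapes; stride-1 "SAME" convolution and 2x2, stride-2 "SAME" pooling are assumed) that reproduces the convolutional part of the table:

```haskell
-- NHWC shapes: (batch, height, width, channels).
-- Filter shapes: (patchHeight, patchWidth, inChannels, outChannels).
type Shape4 = (Int, Int, Int, Int)

-- Stride-1 "SAME" convolution: height and width are unchanged, and the
-- filter's last dimension replaces the channels. Nothing if channels mismatch.
convShape :: Shape4 -> Shape4 -> Maybe Shape4
convShape (n, h, w, c) (_, _, filterIn, filterOut)
  | c == filterIn = Just (n, h, w, filterOut)
  | otherwise     = Nothing

-- 2x2 pooling with stride 2 halves height and width (rounding up).
poolShape :: Shape4 -> Shape4
poolShape (n, h, w, c) = (n, half h, half w, c)
  where half d = (d + 1) `div` 2

-- The two conv/pool rounds for a batch of n 28x28 single-channel images.
convStackShape :: Int -> Maybe Shape4
convStackShape n = do
  conv1 <- convShape (n, 28, 28, 1) (5, 5, 1, 32)
  conv2 <- convShape (poolShape conv1) (5, 5, 32, 64)
  pure (poolShape conv2)

-- Length of each flattened row fed to the first dense layer.
flattenedSize :: Maybe Int
flattenedSize = fmap (\(_, h, w, c) -> h * w * c) (convStackShape 1)
```

Feeding the second layer's weights a 32-channel input works, but a filter expecting 1 input channel would be rejected with Nothing, just as the real graph would fail on mismatched dimensions.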

So for each input image, we get a score for each of the 10 possible labels. We pick the label with the greatest score. (Applying softmax turns these scores into probabilities without changing which one is largest.)
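Picking the label is an arg-max over the ten scores. A plain-list sketch (argMaxLabel is a hypothetical name; in the model, this is the argMax op along dimension 1 of the final results):

```haskell
-- Index of the largest score, or Nothing for an empty list.
-- This version keeps the first index when scores tie.
argMaxLabel :: [Float] -> Maybe Int
argMaxLabel [] = Nothing
argMaxLabel (s : ss) = Just (go 0 s 1 ss)
  where
    go best _ _ [] = best
    go best bestScore i (x : xs)
      | x > bestScore = go i x (i + 1) xs
      | otherwise     = go best bestScore (i + 1) xs
```

For scores [0.1, 2.5, -1, 2.5], the chosen label is 1.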

Results

We’ll run our model again, only this time we’ll use a smaller sample size (100 per training iteration). This allows us to train for more iterations (20000). This takes quite a while to train, but we get these results (printed every 1000 iterations).

Current training error 91.0
Current training error 6.0
Current training error 2.9999971
Current training error 2.9999971
Current training error 0.0
Current training error 0.0
Current training error 0.99999905
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0

test error 1.1799991

Not too bad! Once it got going, we saw very few training errors, though we still ended up with a model that was a tad overfit. The Tensor Flow MNIST expert tutorial suggests using dropout, which helps diminish the effects of overfitting. But this option isn’t available to us yet in Haskell. Still, we got close to 99% accuracy (98.8%), which is a success for us!

And here’s what our final graph looks like. Notice the extra layers we added for convolution:

Conclusion

So that’s it for convolutional neural networks! Our goal was to adapt our previous neural network model to recognize digits. Not only did we pick a harder problem, but we also wanted higher accuracy. We achieved this by using more advanced machine learning techniques. Convolution allowed us to use the full 2D nature of the image. It also checked for the digit no matter where it was in the image. Max pooling enabled us to make our algorithm more efficient while keeping the most important information.

If you’re itching to see what else Haskell can do with Tensor Flow, check out our Tensor Flow Guide. It’ll walk you through some of the trickier parts of getting the library working on your local machine. It will also go through the most important information you need to know about the types in this library.

If you’re new to Haskell, here are a couple more resources you can dig into before you try your hand at Tensor Flow. First, there’s our Getting Started Checklist. This will point you toward some helpful resources for learning the language. Next, you can check out our Stack mini-course so you can learn how to organize a project! If you want to use Tensor Flow in Haskell, Stack is a must!

James Bowen

Putting the Flow in Tensor Flow!

Last week we built our first neural network and used it on a real machine learning problem. We took the Iris data set and built a classifier that took in various flower measurements. It determined, with decent accuracy, the type of flower the measurements referred to.

But we’ve still only seen half the story of Tensor Flow! We’ve constructed many tensors and combined them in interesting ways. We can imagine what is going on with the “flow”, but we haven’t seen a visual representation of that yet.

We’re in luck though, thanks to the Tensor Board application. With it, we can visualize the computation graph we've created. We can also track certain values throughout our program run. In this article, we’ll take our Iris example and show how we can add Tensor Board features to it. Here's the Github repo with all the code so you can follow along!

Add an Event Writer

The first thing to understand about Tensor Board is that it gets its data from a source directory. While we’re running our system, we have to direct it to write events to that directory. This will allow Tensor Board to know what happened in our training run.

eventsDir :: FilePath
eventsDir = "/tmp/tensorflow/iris/logs/"

runIris :: FilePath -> FilePath -> IO ()
runIris trainingFile testingFile = withEventWriter eventsDir $ \eventWriter -> runSession $ do
...

By itself, this doesn’t write anything to that directory though! To see the consequences of this, let’s boot up Tensor Board.

Running Tensor Board

Running our executable again doesn't bring up Tensor Board. It merely logs the information that Tensor Board uses. To actually see that information, we’ll run the tensorboard command.

>> tensorboard --logdir='/tmp/tensorflow/iris/logs'
Starting TensorBoard 47 at http://0.0.0.0:6006

Then we can point our web browser at that address (port 6006). Since we haven't written anything to the log directory yet, there won’t be much for us to see other than some pretty graphics. So let’s start by logging our graph. This is actually quite easy! Remember our model? We can use the logGraph function with our event writer to record it.

model <- build createModel
logGraph eventWriter createModel

Now when we refresh Tensor Board, we’ll see our system’s graph.

What the heck is going on here?

But it’s very large and very confusing. The node names don't tell us much, and it’s not clear what data is going where. Plus, we have no idea what’s going on with our error rate or anything like that. Let’s make a couple of adjustments to fix this.

Adding Summaries

So the first step is to actually specify some measurements that we’ll have Tensor Board plot for us. One node we can use is a “scalar summary”. This provides us with a summary of a particular value over the course of our training run. Let’s do this with our errorRate node. We can use the simple scalarSummary function.

errorRate_ <- render $ 1 - (reduceMean (cast correctPredictions))
scalarSummary "Error" errorRate_
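The errorRate_ node is just 1 minus the fraction of correct predictions. The same arithmetic on plain lists, as a sketch (errorRateOf is a hypothetical helper; in the graph, the predictions come from argMax and the comparison from equal):

```haskell
-- Fraction of predictions that don't match the true labels: 1 - mean(correct).
-- Returns 0 for an empty batch.
errorRateOf :: [Int] -> [Int] -> Double
errorRateOf predicted actual
  | null pairsUp = 0
  | otherwise    = 1 - fromIntegral numCorrect / fromIntegral (length pairsUp)
  where
    pairsUp = zip predicted actual
    numCorrect = length (filter (uncurry (==)) pairsUp)
```

Three correct predictions out of four give an error rate of 0.25, which our training loop prints as 25.0 after multiplying by 100.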

The second type of summary is a histogram summary. We use this on a particular tensor to see the distribution of its values over the course of the run. Let’s do this with our second set of weights. We need to use readValue to go from a Variable to a Tensor.

(finalWeights, finalBiases, finalResults) <-
  buildNNLayer numHiddenUnits irisLabels rectifiedHiddenResults
histogramSummary "Weights" (readValue finalWeights)
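Tensor Board chooses its own bucket boundaries, so the following is only a rough, hypothetical sketch of the kind of information a histogram summary records: how many values fall into each range. It uses equal-width buckets between the smallest and largest value (bucketCounts is a made-up name):

```haskell
-- Count how many values fall into each of k equal-width buckets
-- spanning [minimum values, maximum values].
bucketCounts :: Int -> [Float] -> [Int]
bucketCounts k values
  | k <= 0 || null values = []
  | otherwise = [ length (filter (== b) indices) | b <- [0 .. k - 1] ]
  where
    lo = minimum values
    hi = maximum values
    width = (hi - lo) / fromIntegral k
    -- The maximum value would land one past the end, so clamp it into the last bucket.
    bucketOf v
      | width == 0 = 0
      | otherwise  = min (k - 1) (floor ((v - lo) / width))
    indices = map bucketOf values
```

Recording counts like these at each logged step is what lets Tensor Board show how the distribution of the weights shifts during training.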

So let’s run our program again. We would expect to see these new values show up under the Scalars and Histograms tabs. But they don’t. This is because we still need to write these results to our event writer, and this turns out to be a little complicated. First, before we start training, we have to create a tensor representing all our summaries.

logGraph eventWriter createModel
summaryTensor <- build mergeAllSummaries

Now if we had no placeholders, we could run this tensor whenever we wanted, and it would output the values. But our summary tensors depend on the input placeholders, which complicates the matter. So here’s what we’ll do. We’ll only write out the summaries when we check our error rate (every 100 steps). To do this, we have to change our error rate in the model to take the summary tensor as an extra argument. We’ll also have it return a ByteString alongside the original Float.

data Model = Model
  { train :: TensorData Float
          -> TensorData Int64
          -> Session ()
  , errorRate :: TensorData Float
              -> TensorData Int64
              -> SummaryTensor
              -> Session (Float, ByteString)
  }

Within our model definition, we’ll use this extra parameter. It will run both the errorRate_ tensor AND the summary tensor together with the feeds:

return $ Model
  { train = ...
  , errorRate = \inputFeed outputFeed summaryTensor -> do
      -- Run the error tensor and the merged summaries in one pass over the same feeds
      (errorTensorResult, summaryTensorResult) <- runWithFeeds
        [ feed inputs inputFeed
        , feed outputs outputFeed
        ]
        (errorRate_, summaryTensor)
      return (unScalar errorTensorResult, unScalar summaryTensorResult)
  }

Now we need to modify our calls to errorRate below. We’ll pass the summary tensor as an argument, and get the bytes as output. We’ll write it to our event writer (after decoding), and then we’ll be done!

  -- Training
  forM_ ([0..1000] :: [Int]) $ \i -> do
    trainingSample <- liftIO $ chooseRandomRecords trainingRecords
    let (trainingInputs, trainingOutputs) = convertRecordsToTensorData trainingSample
    (train model) trainingInputs trainingOutputs
    when (i `mod` 100 == 0) $ do
      (err, summaryBytes) <- (errorRate model) trainingInputs trainingOutputs summaryTensor
      let summary = decodeMessageOrDie summaryBytes
      liftIO $ putStrLn $ "Current training error " ++ show (err * 100)
      logSummary eventWriter (fromIntegral i) summary

  liftIO $ putStrLn ""

  -- Testing
  let (testingInputs, testingOutputs) = convertRecordsToTensorData testRecords
  (testingError, _) <- (errorRate model) testingInputs testingOutputs summaryTensor
  liftIO $ putStrLn $ "test error " ++ show (testingError * 100)
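A small detail of the loop above: [0..1000] includes both endpoints, so the summary is written at steps 0, 100, ..., 1000, eleven times in all, and the step-0 summary comes after just one training step. The schedule in plain Haskell (loggedSteps is a hypothetical helper):

```haskell
-- Steps at which a loop over [0 .. lastStep] logs when i `mod` every == 0.
loggedSteps :: Int -> Int -> [Int]
loggedSteps lastStep every = filter (\i -> i `mod` every == 0) [0 .. lastStep]
```

Here loggedSteps 1000 100 is [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000], so Tensor Board receives eleven points on the error curve.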

Now we can see what our summaries look like by running Tensor Board again!

Scalar Summary of our Error Rate

Histogram summary of our final weights.

Annotating our Graph

Now let’s look back to our graph. It’s still a bit confusing. We can clean it up a lot by creating “name scopes”. A name scope is part of the graph that we set aside under a single name. When Tensor Board generates our graph, it will create one big block for the scope. We can then zoom in and examine the individual nodes if we want.

We’ll make three different scopes. First, we’ll make a scope for each of the hidden layers of our neural network. This is quite easy, since we already have a function for creating these. All we have to do is make the function take an extra parameter for the name of the scope we want. Then we wrap the whole function within the withNameScope function.

buildNNLayer :: Int64 -> Int64 -> Tensor v Float -> Text
             -> Build (Variable Float, Variable Float, Tensor Build Float)
buildNNLayer inputSize outputSize input layerName = withNameScope layerName $ do
  weights <- truncatedNormal (vector [inputSize, outputSize]) >>= initializedVariable
  bias <- truncatedNormal (vector [outputSize]) >>= initializedVariable
  let results = (input `matMul` readValue weights) `add` readValue bias
  return (weights, bias, results)

We supply our name further down in the code:

(hiddenWeights, hiddenBiases, hiddenResults) <- 
  buildNNLayer irisFeatures numHiddenUnits inputs "layer1"
let rectifiedHiddenResults = relu hiddenResults
(finalWeights, finalBiases, finalResults) <-
  buildNNLayer numHiddenUnits irisLabels rectifiedHiddenResults "layer2"

Now we’ll add a scope around all our error calculations. First, we combine these into an action wrapped in withNameScope. Then, observing that we need the errorRate_ and train_ steps, we return those from the block. That’s it!

(errorRate_, train_) <- withNameScope "error_calculation" $ do
    actualOutput <- render $ cast $ argMax finalResults (scalar (1 :: Int64))
    let correctPredictions = equal actualOutput outputs
    er <- render $ 1 - (reduceMean (cast correctPredictions))
    scalarSummary "Error" er

    let outputVectors = oneHot outputs (fromIntegral irisLabels) 1 0
    let loss = reduceMean $ fst $ softmaxCrossEntropyWithLogits finalResults outputVectors
    let params = [hiddenWeights, hiddenBiases, finalWeights, finalBiases]
    tr <- minimizeWith adam loss params
    return (er, tr)
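For each example, softmaxCrossEntropyWithLogits compares the one-hot vector for the true label with the softmax of the final layer's outputs, and reduceMean then averages over the batch. A per-example sketch in plain Haskell, using made-up names (oneHotList, logSoftmax, softmaxCrossEntropy). It computes the log-softmax directly, a common way to avoid taking the log of a probability that has underflowed to zero:

```haskell
-- oneHotList label depth onValue offValue, mirroring oneHot outputs depth 1 0 above.
oneHotList :: Int -> Int -> Float -> Float -> [Float]
oneHotList label depth onValue offValue =
  [ if i == label then onValue else offValue | i <- [0 .. depth - 1] ]

-- log (softmax logits), shifted by the largest logit for numerical stability.
logSoftmax :: [Float] -> [Float]
logSoftmax logits = map (\x -> x - m - log total) logits
  where
    m = maximum logits
    total = sum (map (\x -> exp (x - m)) logits)

-- Cross entropy of one example: - sum_i target_i * log (softmax logits)_i
softmaxCrossEntropy :: [Float] -> [Float] -> Float
softmaxCrossEntropy logits targets =
  negate (sum (zipWith (*) targets (logSoftmax logits)))
```

With all-zero logits over three classes, every class gets probability 1/3 and the loss is log 3 (about 1.0986). A confident, correct prediction such as logits [10, 0, 0] for label 0 drives the loss close to 0, which is what minimizing with Adam pushes toward.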

Now when we look at our graph, we see that it’s divided into three parts: our two layers, and our error calculation. All the information flows among these three parts (as well as the "Adam" optimizer portion).

Much Better

Conclusion

By default, Tensor Board graphs can look a little messy. But by adding a little more information to the nodes and using scopes, you can paint a much clearer picture. You can see how the data flows from one end of the application to the other. We can also use summaries to track important information about our graph. We’ll use this most often for the loss function or error rate. Hopefully, we'll see it decline over time.

Next week we’ll add some more complexity to our neural networks. We'll see new tensors for convolution and max pooling. This will allow us to solve the more difficult MNIST digit recognition problem. Stay tuned!

If you’re itching to try out some Tensor Board functionality for yourself, check out our in-depth Tensor Flow guide. It goes into more detail about the practical aspects of using this library. If you want to get the Haskell Tensor Flow library running on your local machine, check it out! Trust me, it's a little complicated, unless you're a Stack wizard already!

And if this is your first exposure to Haskell, try it out! Take a look at our guide to getting started with the language!

James Bowen

Digging in Deep: Solving a Real Problem with Haskell Tensor Flow

Last week we got acquainted with the core concepts of Tensor Flow. We learned about the differences between constants, placeholders, and variable tensors. Both the Haskell and Python bindings have functions to represent these. The Python version was a bit simpler though. Once we had our tensors, we wrote a program that “learned” a simple linear equation.

This week, we’re going to solve an actual machine learning problem. We’re going to use the Iris data set, which contains measurements of different Iris flowers. Each flower belongs to one of three species. Our program will "learn" a function choosing the species from the measurements. This function will involve a fully-connected neural network.

Formatting our Input

The first step in pretty much any machine learning problem is data processing. After all, our data doesn’t magically get resolved into Haskell data types. Luckily, Cassava is a great library to help us out. The Iris data set consists of data in .csv files that each have a header line and then a series of records. They look a bit like this:

120,4,setosa,versicolor,virginica
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
4.9,2.5,4.5,1.7,2
4.9,3.1,1.5,0.1,0
...

Each line contains one record. A record has four flower measurements and a final label. In this case, we have three types of flowers we are trying to classify between: Iris Setosa, Iris Versicolor, and Iris Virginica. So the last column contains the numbers 0, 1, and 2, corresponding to these respective classes.
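Cassava will do the parsing for us, but to make the layout concrete, here's a hand-rolled, base-only sketch of decoding one data line (splitOnComma and parseIrisLine are hypothetical; there's no support for quoted fields, which the sample above doesn't need):

```haskell
import Text.Read (readMaybe)

-- Split a line on commas.
splitOnComma :: String -> [String]
splitOnComma s = case break (== ',') s of
  (field, [])       -> [field]
  (field, _ : rest) -> field : splitOnComma rest

-- Parse "6.4,2.8,5.6,2.2,2" into four measurements and a label (0, 1 or 2).
parseIrisLine :: String -> Maybe ([Float], Int)
parseIrisLine line = case splitOnComma line of
  [a, b, c, d, l] -> do
    features <- mapM readMaybe [a, b, c, d]
    label <- readMaybe l
    if label `elem` [0, 1, 2] then Just (features, label) else Nothing
  _ -> Nothing
```

The header line "120,4,setosa,versicolor,virginica" fails to parse as a record, since its third field isn't a number, which is why we'll tell Cassava to skip it with HasHeader.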

Let's create a data type representing each record. Then we can parse the file line-by-line. Our IrisRecord type will contain the feature data and the resulting label. This type will act as a bridge between our raw data and the tensor format we’ll need to run our learning algorithm. We’ll derive the “Generic” typeclass for our record type, and use this to get FromRecord. Once our type has an instance for FromRecord, we can parse it with ease. As a note, throughout this article, I’ll be omitting the imports section from the code samples. I’ve included a full list of imports from these files as an appendix at the bottom. We'll also be using the OverloadedLists extension throughout.

{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedLists #-}

...

data IrisRecord = IrisRecord
  { field1 :: Float
  , field2 :: Float
  , field3 :: Float
  , field4 :: Float
  , label  :: Int64
  }
  deriving (Generic)

instance FromRecord IrisRecord

Now that we have our type, we’ll write a function, readIrisFromFile, that will read our data in from a CSV file.

readIrisFromFile :: FilePath -> IO (Vector IrisRecord)
readIrisFromFile fp = do
  contents <- readFile fp
  let contentsAsBs = pack contents
  let results = decode HasHeader contentsAsBs :: Either String (Vector IrisRecord)
  case results of
    Left err -> error err
    Right records -> return records

We won’t always want to feed our entire data set into our training system. So given a whole slew of these records, we should be able to pick out a random sample.

sampleSize :: Int
sampleSize = 10

chooseRandomRecords :: Vector IrisRecord -> IO (Vector IrisRecord)
chooseRandomRecords records = do
  let numRecords = Data.Vector.length records
  chosenIndices <- take sampleSize <$> shuffleM [0..(numRecords - 1)]
  return $ fromList $ map (records !) chosenIndices

Once we’ve selected our vector of records to use for each run, we’re still not done. We need to take these records and transform them into the TensorData that we’ll feed into our algorithm. We create items of TensorData by feeding in a shape and then a 1-dimensional vector of values. First, we need to know the shapes of our input and output. Both of these depend on the number of rows in the sample. The “input” will have a column for each of the four features in our set. The output meanwhile will have a single column for the label values.

irisFeatures :: Int64
irisFeatures = 4

irisLabels :: Int64
irisLabels = 3

convertRecordsToTensorData :: Vector IrisRecord -> (TensorData Float, TensorData Int64)
convertRecordsToTensorData records = (input, output)
  where
    numRecords = Data.Vector.length records 
    input = encodeTensorData [fromIntegral numRecords, irisFeatures] (undefined)
    output = encodeTensorData [fromIntegral numRecords] (undefined)

Now all we need to do is take the various records and turn them into one dimensional vectors to encode. Here’s the final function:

convertRecordsToTensorData :: Vector IrisRecord -> (TensorData Float, TensorData Int64)
convertRecordsToTensorData records = (input, output)
  where
    numRecords = Data.Vector.length records 
    input = encodeTensorData [fromIntegral numRecords, irisFeatures]
      (fromList $ concatMap recordToInputs records)
    output = encodeTensorData [fromIntegral numRecords] (label <$> records)
    recordToInputs :: IrisRecord -> [Float]
    recordToInputs rec = [field1 rec, field2 rec, field3 rec, field4 rec]

Neural Network Basics

Now that we’ve got that out of the way, we can start writing our model. Remember, we want to perform two different actions with our model. First, we want to be able to take our training input and train the weights. Second, we want to be able to pass a test data set and determine the error rate. We can represent these two different functions as a single Model object. Remember the Session monad, where we run all our Tensor Flow activities. The training will run an action that alters the variables but returns nothing. The error rate calculation will return us a float value.

data Model = Model
  { train :: TensorData Float -- Training input
          -> TensorData Int64 -- Training output
          -> Session ()
  , errorRate :: TensorData Float -- Test input
              -> TensorData Int64 -- Test output
              -> Session Float
  }
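
The Model type is simply a record of functions that close over the same hidden state (in Tensor Flow, the graph’s variables). As a hypothetical toy illustration of that pattern outside Tensor Flow, the sketch below uses an IORef in place of the variables. ToyModel, createToyModel, and the task of learning a single number are all made up for this example.

```haskell
import Data.IORef (newIORef, readIORef, modifyIORef')

-- Toy analogue of Model: two actions sharing one mutable parameter.
data ToyModel = ToyModel
  { toyTrain :: [Double] -> IO ()      -- nudge the parameter toward the data
  , toyError :: [Double] -> IO Double  -- mean absolute error of the current estimate
  }

-- Hypothetical model that learns a single number: the mean of its inputs.
createToyModel :: Double -> IO ToyModel
createToyModel learningRate = do
  param <- newIORef 0
  let train xs = modifyIORef' param (\p -> p + learningRate * (mean xs - p))
      measure xs = do
        p <- readIORef param
        return (mean [abs (x - p) | x <- xs])
  return ToyModel { toyTrain = train, toyError = measure }
  where
    mean ys = sum ys / fromIntegral (length ys)
```

Callers never see the IORef, just as runIris never touches the weight variables directly: it only calls train and errorRate.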

Now we’re going to build a fully-connected neural network. We’ll have 4 input units (one for each feature) and 3 output units (one for each class we’re trying to predict). In the middle, we’ll use a hidden layer of 10 units. This means we’ll need two sets of weights and biases. We’ll write a function that, given the input and output dimensions, creates the weight and bias variables for a layer. It returns both of these, plus the result tensor of the layer.

buildNNLayer :: Int64 -> Int64 -> Tensor v Float
             -> Build (Variable Float, Variable Float, Tensor Build Float)
buildNNLayer inputSize outputSize input = do
  weights <- truncatedNormal (vector [inputSize, outputSize]) >>= initializedVariable
  bias <- truncatedNormal (vector [outputSize]) >>= initializedVariable
  let results = (input `matMul` readValue weights) `add` readValue bias
  return (weights, bias, results)

We do this in the Build monad, which allows us to construct variables, among other things. To keep things simple, we initialize all our variables from a truncatedNormal distribution. We specify the shape of each variable as a vector tensor, and then initialize it. Finally, we create the resulting tensor by multiplying the input by our weights and adding the bias.
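
For intuition about what buildNNLayer computes, here is a plain-list sketch of the same arithmetic. It does not use Tensor Flow, and matMulL and denseLayer are illustrative names. An input of shape [n, inputSize] times weights of shape [inputSize, outputSize], plus a bias of length outputSize added to every row, gives a result of shape [n, outputSize].

```haskell
import Data.List (transpose)

-- Row-major matrix: a list of rows.
type Matrix = [[Float]]

-- Plain matrix product: [n, k] x [k, m] -> [n, m].
matMulL :: Matrix -> Matrix -> Matrix
matMulL a b = [[sum (zipWith (*) row col) | col <- transpose b] | row <- a]

-- One fully-connected layer: input times weights, plus the bias,
-- where the bias vector is added to every row of the batch.
denseLayer :: Matrix -> [Float] -> Matrix -> Matrix
denseLayer weights bias input = map (zipWith (+) bias) (matMulL input weights)
```

For example, denseLayer [[1, 0, 1], [0, 1, 1]] [0.5, 0.5, 0.5] [[1, 2], [3, 4]] turns a [2, 2] input into the [2, 3] result [[1.5, 2.5, 3.5], [3.5, 4.5, 7.5]].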

Constructing our Model

Now we’ll start building our Model object, again within the Build monad. We begin by specifying our input and output placeholders, as well as the number of hidden units. We’ll also use a batchSize of -1 to account for the fact that we want a variable number of input samples.

irisFeatures :: Int64
irisFeatures = 4

irisLabels :: Int64
irisLabels = 3
-- ^^ From above

createModel :: Build Model
createModel = do
  let batchSize = -1 -- Allows variable sized batches
  let numHiddenUnits = 10
  inputs <- placeholder [batchSize, irisFeatures]
  outputs <- placeholder [batchSize]
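
Here is a rough sketch of what the -1 buys us: a declared dimension of -1 matches any size, so the same placeholders can accept a training sample and the 30-row test set alike. shapeMatches is a hypothetical helper for illustration only, not part of the Haskell Tensor Flow API.

```haskell
-- Hypothetical check: does an actual shape fit a declared one,
-- where -1 in the declared shape means "any size in this dimension"?
shapeMatches :: [Int] -> [Int] -> Bool
shapeMatches declared actual =
  length declared == length actual
    && and (zipWith dimOk declared actual)
  where
    dimOk (-1) _ = True    -- wildcard dimension, like batchSize above
    dimOk d a    = d == a  -- fixed dimension must match exactly
```

With this rule, [-1, 4] accepts both [120, 4] and [30, 4] but rejects [30, 3]; the number of dimensions must still agree.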

Then we’ll get the nodes for the two layers of variables, as well as their results. Between the layers, we’ll add a “rectifier” activation function relu:

(hiddenWeights, hiddenBiases, hiddenResults) <- 
  buildNNLayer irisFeatures numHiddenUnits inputs
let rectifiedHiddenResults = relu hiddenResults
(finalWeights, finalBiases, finalResults) <-
  buildNNLayer numHiddenUnits irisLabels rectifiedHiddenResults
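
The sketch below mirrors that structure on plain lists, as an illustration only: relu zeroes out negative entries, and forward chains the hidden layer, relu, and output layer. The names dense and forward are hypothetical.

```haskell
import Data.List (transpose)

-- relu applied element-wise: negative activations become 0.
relu :: [[Float]] -> [[Float]]
relu = map (map (max 0))

-- Fully-connected layer on plain lists: each input row times the weights, plus the bias.
dense :: [[Float]] -> [Float] -> [[Float]] -> [[Float]]
dense w b x = [zipWith (+) b [sum (zipWith (*) row col) | col <- transpose w] | row <- x]

-- Hidden layer, relu, then output layer, e.g. [n, 4] -> [n, 10] -> [n, 3].
forward :: ([[Float]], [Float]) -> ([[Float]], [Float]) -> [[Float]] -> [[Float]]
forward (w1, b1) (w2, b2) = dense w2 b2 . relu . dense w1 b1
```

The relu is what keeps the two layers from collapsing into one linear map. With w1 = [[1, -1]], b1 = [0, 0], w2 = [[1], [1]] and b2 = [0], forward maps the inputs [[3], [-2]] to [[3], [2]]; dropping the relu would send both inputs to [0].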

Now we have to get the inferred class for each output row. This means calling argMax along dimension 1 to pick the class with the highest score. We then cast the result and render it: cast converts the tensor’s element type, and render fixes it as a concrete node in the graph we’re building. Next, we compare these predictions against our output placeholder to see which ones we got correct. Finally, we make a node that calculates the error rate for this run.

actualOutput <- render $ cast $ argMax finalResults (scalar (1 :: Int64))
let correctPredictions = equal actualOutput outputs
errorRate_ <- render $ 1 - (reduceMean (cast correctPredictions))
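
Here is the same computation sketched on plain lists (argMaxRow and classificationError are hypothetical names): take the index of the largest logit in each row, compare it with the label, and subtract the fraction correct from 1.

```haskell
-- Index of the largest entry in a non-empty row (first index on ties).
argMaxRow :: [Float] -> Int
argMaxRow xs = head [i | (i, x) <- zip [0 ..] xs, x == m]
  where
    m = maximum xs

-- 1 - (fraction of rows whose predicted class equals the label).
classificationError :: [[Float]] -> [Int] -> Float
classificationError logits labels =
  1 - fromIntegral correct / fromIntegral (length labels)
  where
    predictions = map argMaxRow logits
    correct = length (filter id (zipWith (==) predictions labels))
```

For logits [[0.1, 2, 0.3], [3, 1, 0], [0, 0, 5]] and labels [1, 2, 2], the predictions are [1, 0, 2], two of the three are correct, and the error is about 0.333.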

Now we have to set up the actual training. First, we make oneHot vectors for our expected outputs. This means converting the label 0 into the vector [1,0,0], the label 1 into [0,1,0], and so on. We compare these vectors against our raw results (the logits, before we took the argMax) using softmax cross-entropy, and the mean of that comparison is our loss function. Then we make a list of the parameters we want to train. The adam optimizer will minimize our loss function by adjusting these parameters.

let outputVectors = oneHot outputs (fromIntegral irisLabels) 1 0
let loss = reduceMean $ fst $ softmaxCrossEntropyWithLogits finalResults outputVectors
let params = [hiddenWeights, hiddenBiases, finalWeights, finalBiases]
train_ <- minimizeWith adam loss params
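
Here is what the loss computes, sketched on plain lists rather than tensors (all names are hypothetical). The log-softmax subtracts each row’s maximum first for numerical stability; that’s a standard trick, and this sketch doesn’t claim to match the Tensor Flow kernel line for line. In the library, softmaxCrossEntropyWithLogits returns a pair whose first element is the per-row loss (the second is the backpropagated gradient), which is why the code above takes fst.

```haskell
-- oneHot depth label onValue offValue, e.g. oneHot 3 0 1 0 gives [1, 0, 0].
oneHot :: Int -> Int -> Float -> Float -> [Float]
oneHot depth label onValue offValue =
  [if i == label then onValue else offValue | i <- [0 .. depth - 1]]

-- log (softmax zs), computed stably by subtracting the row maximum first.
logSoftmax :: [Float] -> [Float]
logSoftmax zs = [z - m - log total | z <- zs]
  where
    m = maximum zs
    total = sum [exp (z - m) | z <- zs]

-- Cross-entropy between a target distribution and softmax(logits).
softmaxCrossEntropy :: [Float] -> [Float] -> Float
softmaxCrossEntropy logits targets =
  negate (sum (zipWith (*) targets (logSoftmax logits)))

-- Mean loss over a batch, like reduceMean over the per-row losses.
meanLoss :: Int -> [[Float]] -> [Int] -> Float
meanLoss depth logitRows labels = sum losses / fromIntegral (length losses)
  where
    losses = zipWith (\zs l -> softmaxCrossEntropy zs (oneHot depth l 1 0)) logitRows labels
```

With all-zero logits, every class gets probability 1/3, so the loss is log 3 (about 1.0986) whatever the label is. A confident correct row such as [10, 0, 0] with label 0 gives a loss near 0, while the same row with label 1 costs about 10.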

Now our errorRate_ and train_ nodes are ready. There's one last step: we have to feed values into the placeholders, and create functions that take in the tensor data. Remember the feed pattern from last week? We use it again here. With that, our model is complete!

return $ Model
  { train = \inputFeed outputFeed -> 
      runWithFeeds
        [ feed inputs inputFeed
        , feed outputs outputFeed
        ]
        train_
  , errorRate = \inputFeed outputFeed -> unScalar <$>
      runWithFeeds
        [ feed inputs inputFeed
        , feed outputs outputFeed
        ]
        errorRate_
  }
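
As a toy model of the feed idea (not how the library is implemented), the sketch below evaluates a tiny expression graph whose placeholders get their values only at run time, from a list of feeds. Node, Feeds, and runWithFeedsToy are hypothetical names.

```haskell
-- Toy graph: constants, named placeholders, and two operations.
data Node
  = Const Double
  | Placeholder String
  | Add Node Node
  | Mul Node Node

-- Feeds bind placeholder names to values for a single run.
type Feeds = [(String, Double)]

-- Evaluate a node; a placeholder with no matching feed is an error.
runWithFeedsToy :: Feeds -> Node -> Either String Double
runWithFeedsToy feeds node = case node of
  Const x -> Right x
  Placeholder name ->
    maybe (Left ("no feed for " ++ name)) Right (lookup name feeds)
  Add a b -> (+) <$> runWithFeedsToy feeds a <*> runWithFeedsToy feeds b
  Mul a b -> (*) <$> runWithFeedsToy feeds a <*> runWithFeedsToy feeds b
```

For the graph x * 2 + y, feeding x = 3 and y = 1 gives Right 7.0, while leaving out y gives Left "no feed for y", much as Tensor Flow refuses to run a graph that needs an unfed placeholder. The same graph can be run again with different feeds, which is exactly how train and errorRate reuse one graph for different batches.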

Tying it all together

Now we’ll write the main function that runs the session. It has three stages. In the preparation stage, we load our data and use the build function to get our model. Then we train our model over steps 0 through 1000, choosing a random sample and converting its records to tensor data at each step. Every 100 steps, we print the current training error. Finally, we determine the resulting error rate on the test data.

runIris :: FilePath -> FilePath -> IO ()
runIris trainingFile testingFile = runSession $ do
  -- Preparation
  trainingRecords <- liftIO $ readIrisFromFile trainingFile
  testRecords <- liftIO $ readIrisFromFile testingFile
  model <- build createModel

  -- Training
  forM_ ([0..1000] :: [Int]) $ \i -> do
    trainingSample <- liftIO $ chooseRandomRecords trainingRecords
    let (trainingInputs, trainingOutputs) = convertRecordsToTensorData trainingSample
    (train model) trainingInputs trainingOutputs
    when (i `mod` 100 == 0) $ do
      err <- (errorRate model) trainingInputs trainingOutputs
      liftIO $ putStrLn $ "Current training error " ++ show (err * 100)

  liftIO $ putStrLn ""

  -- Testing
  let (testingInputs, testingOutputs) = convertRecordsToTensorData testRecords
  testingError <- (errorRate model) testingInputs testingOutputs
  liftIO $ putStrLn $ "test error " ++ show (testingError * 100)

  return ()
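
Since [0..1000] includes both endpoints, the loop runs 1001 steps, and the i `mod` 100 == 0 check fires 11 times (at 0, 100, ..., 1000), which matches the 11 training error lines in the results below. The pure sketch below captures that schedule with a generic step function. trainWithLog is a hypothetical name, and, as in runIris, the metric is recorded after that step’s update.

```haskell
import Data.List (foldl')

-- Run `step` for i in [0 .. lastStep], logging `metric` every `every` steps.
trainWithLog :: Int -> Int -> (Int -> s -> s) -> (s -> Double) -> s -> (s, [(Int, Double)])
trainWithLog lastStep every step metric s0 = (finalState, reverse logs)
  where
    (finalState, logs) = foldl' go (s0, []) [0 .. lastStep]
    go (s, acc) i =
      let s' = step i s
          acc' = if i `mod` every == 0 then (i, metric s') : acc else acc
      in (s', acc')
```

For instance, with a step that halves a number starting from 1024, trainWithLog 1000 100 logs steps 0, 100, ..., 1000, and the first logged value is 512.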

Results

When we actually run all of this, we get output like the following: the training error every 100 steps, then the error on our test set.

Current training error 60.000004
Current training error 30.000002
Current training error 39.999996
Current training error 19.999998
Current training error 10.000002
Current training error 10.000002
Current training error 19.999998
Current training error 19.999998
Current training error 10.000002
Current training error 10.000002
Current training error 0.0

test error 3.333336

Our test sample size was 30, so this means we got 29 out of 30 right this time around. Results do change from run to run, though (I obviously used the best results I found). Since our sample size is so small, the results have high variance; sometimes the error rate is as high as 40%. Generally we’ll want to train longer on a larger data set so that we get more consistent results, but this is a good start.
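
The arithmetic behind that 29/30: on 30 test samples, each misclassification moves the error rate by 100/30 ≈ 3.33 percentage points, so the reported 3.333336 (3.33 plus some Float rounding) corresponds to exactly one miss. Here is a small sketch of that calculation, with hypothetical helper names:

```haskell
-- Number of misclassified samples implied by an error percentage on n samples.
misclassified :: Double -> Int -> Int
misclassified errorPct n = round (errorPct / 100 * fromIntegral n)

-- Smallest possible step in the error rate, in percentage points, on n samples.
granularity :: Int -> Double
granularity n = 100 / fromIntegral n
```

misclassified 3.333336 30 evaluates to 1, and granularity 30 is about 3.33, which is also why the test error can only move in coarse jumps from run to run.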

Conclusion

In this article we went over the basics of making a neural network using the Haskell Tensor Flow library. We made a fully-connected neural network and fed in real data we parsed using the Cassava library. This network was able to learn a function to classify flowers from the Iris data set. Considering the small amount of data, we got some good results.

Come back next week, when we’ll see how to add more summary information to our Tensor Flow graph. We’ll use the TensorBoard application to view our graph in a visual format.

For more details on installing the Haskell Tensor Flow system, check out our In-Depth Tensor Flow Tutorial. It should walk you through the important steps in running the code on your own machine.

Perhaps you’ve never tried Haskell before at all, and want to see what it’s like. Maybe I’ve convinced you that Haskell is in fact the future of AI. In that case, you should check out our Getting Started Checklist for some tools on starting with the language.

Appendix: All Imports

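Documentation for Haskell Tensor Flow is still a major work in progress. So I want to make sure I explicitly list the modules you need to import for all the different functions we used here. The code also relies on two language extensions: OverloadedLists, so that list literals like [batchSize, irisFeatures] can be used as tensor shapes, and DeriveGeneric, for deriving Generic on the IrisRecord type.

{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedLists #-}

import Control.Monad (forM_, when)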
import Control.Monad.IO.Class (liftIO)
import Data.ByteString.Lazy.Char8 (pack)
import Data.Csv (FromRecord, decode, HasHeader(..))
import Data.Int (Int64)
import Data.Vector (Vector, length, fromList, (!))
import GHC.Generics (Generic)
import System.Random.Shuffle (shuffleM)

import TensorFlow.Core (TensorData, Session, Build, render, runWithFeeds, feed, unScalar, build,
                        Tensor, encodeTensorData)
import TensorFlow.Minimize (minimizeWith, adam)
import TensorFlow.Ops (placeholder, truncatedNormal, add, matMul, relu,
                      argMax, scalar, cast, oneHot, reduceMean, softmaxCrossEntropyWithLogits, 
                      equal, vector)
import TensorFlow.Session (runSession)
import TensorFlow.Variable (readValue, initializedVariable, Variable)