James Bowen

Tangled Webs: Testing an Integrated System

In the last few articles, we’ve combined several useful Haskell libraries to make a small web app. We used Persistent to create a schema with automatic migrations for our database. Then we used Servant to expose this database as an API through a couple of simple queries. Finally, we used Redis to act as a cache so that repeated requests for the same user are served faster.

For our next step, we’ll tackle a thorny issue: testing. How do we test a system that has so many moving parts? There are a couple of general approaches we could take. At one end of the spectrum, we could mock out most of our API calls and services. This gives our tests deterministic behavior, which is desirable since we want to tie deployments to test results. But we also want to test our API faithfully. So at the other end of the spectrum is the approach we’ll try in this article: we'll set up functions that run our API and other services, and then use before and after hooks to exercise them.

Creating Client Functions for Our API

Calling our API from our tests means we’ll want a way to make API calls programmatically. We can do this with amazing ease with a Servant API by using the servant-client library. This library has one main function: client. This function takes a proxy for our API and generates programmatic client functions. Let's remember our basic endpoint types (after resolving connection information parameters):

fetchUsersHandler :: Int64 -> Handler User
createUserHandler :: User -> Handler Int64

We’d like to be able to call these endpoints with functions that use the same parameters. Those types might look something like this:

fetchUserClient :: Int64 -> m User
createUserClient :: User -> m Int64

Here m is some monad, and the servant-client library provides exactly such a monad: ClientM. So let’s rewrite these type signatures, but leave them seemingly unimplemented:

fetchUserClient :: Int64 -> ClientM User
createUserClient :: User -> ClientM Int64

Now we’ll construct a pattern match that combines these function names with the :<|> operator. As always, we need to make sure we do this in the same order as the original API type. Then we’ll set this pattern to be the result of calling client on a proxy for our API:

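import Data.Proxy (Proxy(..))
import Servant.API ((:<|>)(..))
import Servant.Client (ClientM, client)
fetchUserClient :: Int64 -> ClientM User
createUserClient :: User -> ClientM Int64
(fetchUserClient :<|> createUserClient) = client (Proxy :: Proxy UsersAPI)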

And that’s it! The Servant library fills in the details for us and implements these functions! We’ll see how we can actually call these functions later in the article.
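To see why the order of the pattern matters, here is a toy model of the mechanism. It is not servant-client itself. Servant's (:<|>) is an ordinary data constructor, so the value that client returns can be taken apart with a pattern binding, just like a pair. In this sketch toyClient, fetchToy, and createToy are hypothetical names, and plain functions stand in for the ClientM functions:

```haskell
{-# LANGUAGE TypeOperators #-}

-- Toy model, not servant-client: Servant's (:<|>) is an ordinary
-- data constructor with the same shape as this one.
data a :<|> b = a :<|> b
infixr 3 :<|>

-- Hypothetical stand-in for `client (Proxy :: Proxy UsersAPI)`:
-- plain functions play the role of the generated ClientM functions.
toyClient :: (Int -> String) :<|> (String -> Int)
toyClient = (\uid -> "user-" ++ show uid) :<|> length

-- The same top-level pattern binding as in the article. Each name gets the
-- function in its position, so the order must match the API type.
fetchToy :: Int -> String
createToy :: String -> Int
(fetchToy :<|> createToy) = toyClient
```

Swapping the two names in the pattern would not type-check here, because each name's signature must match the function in that position. The real client functions get the same protection whenever the endpoints' types differ.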

Setting Up the Tests

We’d like to get to the business of deciding on our test cases and writing them. But first we need to make sure that our tests have a proper environment. This means 3 things. First we need to fetch the connection information for our data stores and API. This means the PGInfo, the RedisInfo, and the ClientEnv we’ll use to call the client functions we wrote. Second, we need to actually migrate our database so it has the proper tables. Third, we need to make sure our server is actually running. Let’s start with the connection information, as this is easy:

import Database (fetchPostgresConnection, fetchRedisConnection)

...

setupTests = do
  pgInfo <- fetchPostgresConnection
  redisInfo <- fetchRedisConnection
  ...

Now to create our client environment, we’ll need two main things: a manager for the network connections and the base URL for the API. Since we’re running the API locally, we’ll use a localhost URL. A manager built from tlsManagerSettings (from the http-client-tls package) will work fine for us:

import Network.HTTP.Client (newManager)
import Network.HTTP.Client.TLS (tlsManagerSettings)
import Servant.Client (ClientEnv(..), parseBaseUrl)

setupTests = do
  pgInfo <- fetchPostgresConnection
  redisInfo <- fetchRedisConnection
  mgr <- newManager tlsManagerSettings
  baseUrl <- parseBaseUrl "http://127.0.0.1:8000"
  let clientEnv = ClientEnv mgr baseUrl  -- servant-client >= 0.13: mkClientEnv mgr baseUrl

Now we can run our migration, which will ensure that our users table exists:

import Schema (migrateAll)

setupTests = do
  pgInfo <- fetchPostgresConnection
  ...
  runStdoutLoggingT $ withPostgresqlConn pgInfo $ \dbConn ->
    runReaderT (runMigrationSilent migrateAll) dbConn

Last of all, we’ll start our server with runServer from our API module. We’ll fork this off to a separate thread, as otherwise it will block the test thread! We’ll wait for a second afterward to make sure it actually loads before the tests run (there are less hacky ways to do this of course). But then we’ll return all the important information we need, and we're done with test setup:

setupTests :: IO (PGInfo, RedisInfo, ClientEnv, ThreadId)
setupTests = do
  pgInfo <- fetchPostgresConnection
  redisInfo <- fetchRedisConnection
  mgr <- newManager tlsManagerSettings
  baseUrl <- parseBaseUrl "http://127.0.0.1:8000"
  let clientEnv = ClientEnv mgr baseUrl
  runStdoutLoggingT $ withPostgresqlConn pgInfo $ \dbConn ->
    runReaderT (runMigrationSilent migrateAll) dbConn
  serverThreadId <- forkIO runServer  -- run the API in a background thread
  threadDelay 1000000                 -- crude: give the server one second to start
  return (pgInfo, redisInfo, clientEnv, serverThreadId)
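One less hacky alternative to a fixed threadDelay is to poll a readiness check with a bounded number of retries. The sketch below uses a hypothetical check: an IORef flag that a fake server thread flips. Against the real server, the check could instead be a cheap request that succeeds once port 8000 accepts connections. The names waitUntilReady and demoWait are illustrative only:

```haskell
import Control.Concurrent (forkIO, threadDelay)
import Data.IORef (newIORef, readIORef, writeIORef)

-- Retry `isReady` up to `attempts` times, sleeping `delayMicros` microseconds
-- between tries. Returns True as soon as the check succeeds, False otherwise.
waitUntilReady :: Int -> Int -> IO Bool -> IO Bool
waitUntilReady attempts delayMicros isReady
  | attempts <= 0 = pure False
  | otherwise = do
      ok <- isReady
      if ok
        then pure True
        else do
          threadDelay delayMicros
          waitUntilReady (attempts - 1) delayMicros isReady

-- Demo with a hypothetical stand-in for runServer: a thread that becomes
-- "ready" after about 50ms, polled every 10ms for up to about one second.
demoWait :: IO Bool
demoWait = do
  ready <- newIORef False
  _ <- forkIO $ do
    threadDelay 50000      -- simulated startup time
    writeIORef ready True  -- the "server" is now accepting requests
  waitUntilReady 100 10000 (readIORef ready)
```

In setupTests, the fixed delay could then become something like waitUntilReady 50 100000 serverIsUp. Here serverIsUp is a hypothetical IO Bool that tries a request and catches the connection error.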

Organizing our 3 Test Cases

Now that we’re all set up, we can decide on our test cases. We’ll look at 3. First, if we have an empty database and we fetch a user by some arbitrary ID, we’ll expect an error. Further, we should expect that the user does not exist in the database or in the cache, even after calling fetch.

In our second test case, we’ll look at the effects of calling the create endpoint. We’ll save the key we get from this endpoint. Then we’ll verify that this user exists in the database, but NOT in the cache. Finally, our third case will insert the user with the create endpoint and then fetch the user. We’ll expect at the end of this that in fact the user exists in both the database AND the cache.
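Before writing the hooks, it can help to pin the three expectations down against a tiny pure model of the system. This sketch rests on simplifying assumptions and is not the real code. Maps stand in for Postgres and Redis, users are plain Strings, and create, fetch, and presence are hypothetical names:

```haskell
import qualified Data.Map.Strict as Map

-- A pure model of the system, not the real Postgres/Redis code.
-- Keys are Ints, users are plain Strings.
data System = System
  { dbTable  :: Map.Map Int String  -- stands in for Postgres
  , cacheMap :: Map.Map Int String  -- stands in for Redis
  , nextKey  :: Int
  }

emptySystem :: System
emptySystem = System Map.empty Map.empty 1

-- The create endpoint writes to the database only, never to the cache.
create :: String -> System -> (Int, System)
create user s = (k, s { dbTable = Map.insert k user (dbTable s), nextKey = k + 1 })
  where
    k = nextKey s

-- The fetch endpoint checks the cache, then the database,
-- and copies a database hit into the cache.
fetch :: Int -> System -> (Either String String, System)
fetch k s = case Map.lookup k (cacheMap s) of
  Just u -> (Right u, s)
  Nothing -> case Map.lookup k (dbTable s) of
    Just u -> (Right u, s { cacheMap = Map.insert k u (cacheMap s) })
    Nothing -> (Left "Could not find user with that ID", s)

-- (in database, in cache), like the checks in the before hooks.
presence :: Int -> System -> (Bool, Bool)
presence k s = (Map.member k (dbTable s), Map.member k (cacheMap s))
```

For example, fetch 1 emptySystem returns a Left and leaves presence 1 at (False, False), which matches the first test case. Creating a user yields (True, False), and creating then fetching yields (True, True).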

We organize each of our tests into up to three parts: the “before hook”, the test assertions, and the “after hook”. A “before hook” is some IO code that runs first and passes particular results to our test assertions. We want to make sure it’s done running BEFORE any test assertions. This way, there’s no interleaving of effects between our test output and the API calls. Each before hook will first make the API calls we want. Then it will inspect our different data stores to determine whether certain users exist.

We also want our tests to be database-neutral. That is, the database and cache should be in the same state after the test as they were before. So we’ll also have “after hooks” that run after our tests have finished (if we’ve actually created anything). The after hooks will delete any new entries. This means our before hooks also have to pass the keys for any database entities they create. This way the after hooks know what to delete.
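One detail is worth knowing here. In hspec, before runs its hook before every it item in the spec (beforeAll runs it only once), and after likewise runs after every item. So in the second test case, the before hook creates a fresh user for each of the two assertions, and the after hook deletes each one. The sketch below models that per-item behavior with bracket. runWithHooks and the fake IORef table are hypothetical stand-ins, not hspec's implementation:

```haskell
import Control.Exception (bracket)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import qualified Data.Set as Set

-- Hypothetical per-item hook runner modelled on hspec's before/after:
-- for EACH item, run the before hook, pass its result to the item, then
-- run the after hook. bracket runs the after hook even if the item throws.
runWithHooks :: IO a -> (a -> IO ()) -> [a -> IO r] -> IO [r]
runWithHooks beforeHook afterHook items =
  mapM (bracket beforeHook afterHook) items

-- A fake "database": a set of row keys plus a key counter.
data FakeDB = FakeDB
  { rows    :: IORef (Set.Set Int)
  , counter :: IORef Int
  }

newFakeDB :: IO FakeDB
newFakeDB = FakeDB <$> newIORef Set.empty <*> newIORef 0

-- Before hook: insert a new row and return its key (like beforeHook2).
insertRow :: FakeDB -> IO Int
insertRow db = do
  modifyIORef' (counter db) (+ 1)
  k <- readIORef (counter db)
  modifyIORef' (rows db) (Set.insert k)
  pure k

-- After hook: delete the row the before hook created (like afterHook).
deleteRow :: FakeDB -> Int -> IO ()
deleteRow db k = modifyIORef' (rows db) (Set.delete k)

-- Test item: is the freshly created row visible?
rowExists :: FakeDB -> Int -> IO Bool
rowExists db k = Set.member k <$> readIORef (rows db)
```

With two items, the before hook runs twice (keys 1 and 2) and each item sees its own row. At the end the table is empty, which is the database-neutral property the after hooks are meant to provide.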

Last of all, we need the testing code that makes assertions about the results. These will be pretty straightforward, as we’ll see below.

Test #1

For our first test, we’ll start by making a client call to our API. We use runClientM combined with our clientEnv and the fetchUserClient function. Next, we’ll check that the call does in fact return an error, as it should. Then we’ll add two more lines checking whether there’s an entry with the arbitrary ID in our database and our cache. Finally, we return all three boolean values:

beforeHook1 :: ClientEnv -> PGInfo -> RedisInfo -> IO (Bool, Bool, Bool)
beforeHook1 clientEnv pgInfo redisInfo = do
  callResult <- runClientM (fetchUserClient 1) clientEnv
  let throwsError = isLeft (callResult)
  inPG <- isJust <$> fetchUserPG pgInfo 1
  inRedis <- isJust <$> fetchUserRedis redisInfo 1
  return (throwsError, inPG, inRedis)

Now we’ll write our assertion. Since we’re using a before hook returning three booleans, the type of our Spec will be SpecWith (Bool, Bool, Bool). Each it assertion will take this boolean tuple as a parameter, though we’ll only use one for each line.

spec1 :: SpecWith (Bool, Bool, Bool)
spec1 = describe "After fetching on an empty database" $ do
  it "The fetch call should throw an error" $ \(throwsError, _, _) -> throwsError `shouldBe` True
  it "There should be no user in Postgres" $ \(_, inPG, _) -> inPG `shouldBe` False
  it "There should be no user in Redis" $ \(_, _, inRedis) -> inRedis `shouldBe` False

And that’s all we need for the first test! We don’t need an after hook since it doesn’t add anything to our database.

Tests 2 and 3

Now that we’re a little more familiar with how this code works, let’s take a quick look at the next before hook. This time we’ll first try creating our user. If this fails for whatever reason, we’ll throw an error so that this test case fails instead of continuing without a key. Then we can use the key to check whether the user exists in Postgres and Redis. We return the boolean values and the key.

beforeHook2 :: ClientEnv -> PGInfo -> RedisInfo -> IO (Bool, Bool, Int64)
beforeHook2 clientEnv pgInfo redisInfo = do
  userKeyEither <- runClientM (createUserClient testUser) clientEnv
  case userKeyEither of
    Left _ -> error "DB call failed on spec 2!"
    Right userKey -> do 
      inPG <- isJust <$> fetchUserPG pgInfo userKey
      inRedis <- isJust <$> fetchUserRedis redisInfo userKey
      return (inPG, inRedis, userKey)

Now our spec will look similar. This time we expect to find a user in Postgres, but not in Redis.

spec2 :: SpecWith (Bool, Bool, Int64)
spec2 = describe "After creating the user but not fetching" $ do
  it "There should be a user in Postgres" $ \(inPG, _, _) -> inPG `shouldBe` True
  it "There should be no user in Redis" $ \(_, inRedis, _) -> inRedis `shouldBe` False

Now we need to add the after hook, which will delete the user from the database and cache. Of course, we expect the user won’t exist in the cache, but we include this since we’ll need it in the final example:

afterHook :: PGInfo -> RedisInfo -> (Bool, Bool, Int64) -> IO ()
afterHook pgInfo redisInfo (_, _, key) = do
  deleteUserCache redisInfo key
  deleteUserPG pgInfo key

Last, we’ll write one more test case. This will mimic the previous case, except we’ll throw in a call to fetch in between. As a result, we expect the user to be in both Postgres and Redis:

beforeHook3 :: ClientEnv -> PGInfo -> RedisInfo -> IO (Bool, Bool, Int64)
beforeHook3 clientEnv pgInfo redisInfo = do
  userKeyEither <- runClientM (createUserClient testUser) clientEnv
  case userKeyEither of
    Left _ -> error "DB call failed on spec 3!"
    Right userKey -> do 
      _ <- runClientM (fetchUserClient userKey) clientEnv 
      inPG <- isJust <$> fetchUserPG pgInfo userKey
      inRedis <- isJust <$> fetchUserRedis redisInfo userKey
      return (inPG, inRedis, userKey)

spec3 :: SpecWith (Bool, Bool, Int64)
spec3 = describe "After creating the user and fetching" $ do
  it "There should be a user in Postgres" $ \(inPG, _, _) -> inPG `shouldBe` True
  it "There should be a user in Redis" $ \(_, inRedis, _) -> inRedis `shouldBe` True

And it will use the same after hook as case 2, so we’re done!

Hooking in and running the tests

The last step is to glue all our pieces together with hspec, before, and after. Here’s our main function, which also kills the thread running the server once it’s done:

main :: IO ()
main = do
  (pgInfo, redisInfo, clientEnv, tid) <- setupTests
  hspec $ before (beforeHook1 clientEnv pgInfo redisInfo) spec1
  hspec $ before (beforeHook2 clientEnv pgInfo redisInfo) $ after (afterHook pgInfo redisInfo) $ spec2
  hspec $ before (beforeHook3 clientEnv pgInfo redisInfo) $ after (afterHook pgInfo redisInfo) $ spec3
  killThread tid 
  return ()

And now our tests should pass!

After fetching on an empty database
  The fetch call should throw an error
  There should be no user in Postgres
  There should be no user in Redis

Finished in 0.0410 seconds
3 examples, 0 failures

After creating the user but not fetching
  There should be a user in Postgres
  There should be no user in Redis

Finished in 0.0585 seconds
2 examples, 0 failures

After creating the user and fetching
  There should be a user in Postgres
  There should be a user in Redis

Finished in 0.0813 seconds
2 examples, 0 failures

Using Docker

So when I say “the tests pass,” I mean they pass on my system. But if you were to clone the code as is and try to run the tests, you would get failures. The tests depend on Postgres and Redis, so if you don't have them running, they fail! It is quite annoying to have your tests depend on these outside services. This is the weakness of designing our tests this way. It increases the on-boarding time for anyone coming into your codebase. The new person has to figure out which things they need to run, install them, and so on.

So how do we fix this? One answer is to use Docker. Docker allows you to create containers that run particular services. This spares you from worrying about the details of setting up the services on your local machine. Even more important, you can deploy a Docker image to your remote environments, so your development and production environments will match your local system. To set up this process, we’ll create a description of the services we want running in Docker. We do this with a docker-compose file. Here’s what ours looks like:

version: '2'

services:
  postgres:
    image: postgres:9.6
    container_name: prod-haskell-series-postgres
    ports:
      - "5432:5432"

  redis:
    image: redis:4.0
    container_name: prod-haskell-series-redis
    ports:
      - "6379:6379"

Then you can start these services with docker-compose up. Granted, you do have to install and run Docker. But if you have several different services, this is a much easier on-boarding process. Better yet, the "compose" file ensures everyone uses the same versions of these services.

Even with this container running, the tests will still fail! That’s because you also need the tests themselves to be running on your Docker cluster. But with Stack, this is easy! We’ll add the following flag to our stack.yaml file:

docker:
  enable: true

Now, whenever you build and test your program, you will do so on Docker. The first time you do this, Docker will need to set everything up on the container. This means it will have to download Stack and ALL the different packages you use. So the first run will take a while. But subsequent runs will be normal. So after all that finishes, NOW the tests should work!

Conclusion

Testing integrated systems is hard. We can try mocking out the behavior of external services, but this can lead to a test version of our program that isn’t faithful to the production system. Using the before and after hooks from Hspec is a great way to make sure all your external events happen first. Then you can pass those results to simpler test assertions.

When it comes time to run your system, it helps if you can bring up all your external services with one command! Docker allows you to do this by listing the different services in the docker-compose file. Then, Stack makes it easy to run your program and tests on a docker container, so you can use the services!

Stack is the key to all this integration. If you’ve never used Stack before, you should check out our free mini-course. It will teach you all the basics of organizing a Haskell project using Stack.

If this is your first exposure to Haskell, I’ve hopefully convinced you of some of its awesome possibilities! Take a look at our Getting Started Checklist and get learning!


A Cache is Fast: Enhancing our API with Redis

In the last couple of weeks we’ve used Persistent to store a User type in a PostgreSQL database. Then we were able to use Servant to create a very simple API that exposed this database to the outside world. This week, we’re going to look at how we can improve the performance of our API using a Redis cache.

One cannot overstate the importance of caching in both software and hardware. There's a hierarchy of storage types: registers, RAM, the file system, and remote databases. Accessing each of these gets progressively slower (by orders of magnitude). But the faster means of storage are more expensive, so we can’t always have as much as we'd like.

But memory usage operates on a very important principle: when we use a piece of memory once, we’re very likely to use it again in the near future. So when we pull something out of long-term memory, we can temporarily store it in short-term memory as well. This way when we need it again, we can get it faster. After a certain point, that item will be overwritten by other more urgent items. This is the essence of caching.

Redis 101

Redis is an application that allows us to create a key-value store of items. It functions like a database, except that it only looks items up by their keys. It lacks the sophistication of joins, foreign key references, and indices. So we can’t run the kinds of sophisticated queries that are possible on an SQL database. But we can run simple key lookups, and we can do them faster. In this article, we'll use Redis as a short-term cache for our user objects.

For this article, we've got one main goal for cache integration. Whenever we “fetch” a user using the GET endpoint in our API, we want to store that user in our Redis cache. Then the next time someone requests that user from our API, we'll grab them out of the cache. This will save us the trouble of making a longer call to our Postgres database.

Connecting to Redis

Haskell's Redis library has a lot of similarities to Persistent and Postgres. First, we’ll need some sort of data that tells us where to look for our database. For Postgres, we used a simple ConnectionString with a particular format. Redis uses a full data type called ConnectInfo.

data ConnectInfo = ConnectInfo
  { connectHost :: HostName -- String
  , connectPort :: PortID   -- (Can just be a number)
  , connectAuth :: Maybe ByteString
  , connectDatabase :: Integer
  , connectMaxConnections :: Int
  , connectMaxIdleTime :: NominalDiffTime
  -- newer hedis versions add further fields, such as connectTimeout
  }

This has many of the same fields we stored in our PG string, like the host IP address and the port number. The rest of this article assumes you are running a local Redis instance at port 6379. This means we can use defaultConnectInfo. As always, in a real system you’d want to grab this information out of a configuration, so you’d need IO.

fetchRedisConnection :: IO ConnectInfo
fetchRedisConnection = return defaultConnectInfo

With Postgres, we used withPostgresqlConn to actually connect to the database. With Redis, we do this with the connect function. We'll get a Connection object that we can use to run Redis actions.

connect :: ConnectInfo -> IO Connection

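With this connection, we simply pass it to runRedis along with an action. Here’s the wrapper runRedisAction we’ll write for that:

runRedisAction :: ConnectInfo -> Redis a -> IO a
runRedisAction redisInfo action = do
  -- Note: connect creates a new connection pool on every call. That keeps the
  -- example short; a real app would connect once at startup and reuse it.
  connection <- connect redisInfo
  runRedis connection action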

The Redis Monad

Just as we used the SqlPersistT monad with Persistent, we’ll use the Redis monad to interact with our Redis cache. Our API is simple, so we’ll stick to three basic functions. The real types of these functions are a bit more complicated, but that's because of polymorphism related to transactions, which we won't be using.

get :: ByteString -> Redis (Either Reply (Maybe ByteString))
set :: ByteString -> ByteString -> Redis (Either Reply Status)
setex :: ByteString -> Integer -> ByteString -> Redis (Either Reply Status)

Haskell's Redis library works with raw bytes, so everything we set here will use ByteString items. Once we’ve serialized our data, these functions are all we need. The get function takes a ByteString of the key and delivers the value as another ByteString. The set function takes both the serialized key and value and stores them in the cache. The setex function does the same thing as set, except that it also takes an expiration time in seconds (its middle argument) for the item we’re storing.

Expiration is a very useful feature to be aware of, since most relational databases don’t have it. The nature of a cache is that it’s only supposed to store a subset of our information at any given time. If we never expire or delete anything, it might eventually store our whole database. That would defeat the purpose of using a cache! Its memory footprint should remain low compared to our database. So we'll use setex in our API.
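The semantics of setex can be sketched with a toy expiring map. This is only a model, and it assumes time is passed in explicitly as whole seconds. It is not how Redis stores or evicts keys, and TTLCache, setexAt, and getAt are hypothetical names:

```haskell
import qualified Data.Map.Strict as Map

type Seconds = Integer

-- Toy expiring cache: each value is stored with the absolute time it expires.
newtype TTLCache k v = TTLCache (Map.Map k (Seconds, v))

emptyCache :: TTLCache k v
emptyCache = TTLCache Map.empty

-- Like setex: store v under k, expiring ttl seconds after `now`.
setexAt :: Ord k => Seconds -> k -> Seconds -> v -> TTLCache k v -> TTLCache k v
setexAt now k ttl v (TTLCache m) = TTLCache (Map.insert k (now + ttl, v) m)

-- Like get: an entry whose expiry time has been reached reads as missing.
getAt :: Ord k => Seconds -> k -> TTLCache k v -> Maybe v
getAt now k (TTLCache m) = case Map.lookup k m of
  Just (expiresAt, v) | now < expiresAt -> Just v
  _ -> Nothing
```

An entry set at time 0 with a TTL of 3600 is visible at 3599 and gone at 3600. A real cache also has to reclaim the memory of expired entries; this model only hides them.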

Saving a User in Redis

So now let’s move on to the actions we’ll actually use in our API. First, we’ll write a function that stores a key-value pair of an Int64 key and a User in the cache. Here’s how we start:

cacheUser :: ConnectInfo -> Int64 -> User -> IO ()
cacheUser redisInfo uid user = runRedisAction redisInfo $ setex ??? ??? ???

All we need to do now is convert our key and our value to ByteString values. We'll keep it simple and use Data.ByteString.Char8 combined with our Show and Read instances. Then we’ll create a Redis action using setex and expire the key after 3600 seconds (one hour).

import Control.Monad (void)
import Data.ByteString.Char8 (pack, unpack)

...

cacheUser :: ConnectInfo -> Int64 -> User -> IO ()
cacheUser redisInfo uid user = runRedisAction redisInfo $ void $ 
  setex (pack . show $ uid) 3600 (pack . show $ user)

(We use void to ignore the result of the Redis call).

Fetching from Redis

Fetching a user is a similar process. We’ll take the connection information and the key we’re looking for. The action we’ll create uses the bytestring representation and calls get. But we can’t ignore the result of this call like we could before! Retrieving anything gives us Either e (Maybe ByteString). A Left response indicates an error, while Right Nothing indicates the key doesn’t exist. We’ll ignore the errors and treat the result as Maybe User though. If any error comes up, we’ll return Nothing. This means we run a simple pattern match:

fetchUserRedis :: ConnectInfo -> Int64 -> IO (Maybe User)
fetchUserRedis redisInfo uid = runRedisAction redisInfo $ do
  result <- Redis.get (pack . show $ uid)
  case result of
    Right (Just userString) -> return $ Just (read . unpack $ userString)
    _ -> return Nothing

If we do find something for that key, we’ll read it out of its ByteString format and then we’ll have our final User object.
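The show/read encoding can be checked in isolation. The sketch below uses a hypothetical User record with the fields from the JSON example in the Servant article. It also swaps read for readMaybe (from Text.Read), so that a malformed cached value yields Nothing instead of crashing the handler. Because show escapes non-ASCII characters, the shown string is plain ASCII, and Data.ByteString.Char8.pack does not truncate it:

```haskell
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BC
import Text.Read (readMaybe)

-- Hypothetical stand-in for the project's User type (fields taken from the
-- JSON example in the Servant article). It only needs Show and Read here.
data User = User
  { userName       :: String
  , userEmail      :: String
  , userAge        :: Int
  , userOccupation :: String
  } deriving (Show, Read, Eq)

-- Same encoding as cacheUser: show the value, then pack it.
encodeUser :: User -> ByteString
encodeUser = BC.pack . show

-- Same shape as the pattern match in fetchUserRedis, but using readMaybe:
-- a Redis error, a missing key, or an unparseable value all become Nothing.
decodeReply :: Either e (Maybe ByteString) -> Maybe User
decodeReply (Right (Just bytes)) = readMaybe (BC.unpack bytes)
decodeReply _ = Nothing
```

A value cached with encodeUser decodes back to the same User, including names with characters such as "é".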

Applying this to our API

Now that we’re all set up with our Redis functions, we have to update the fetchUsersHandler to use this cache. First, we now need to pass the Redis connection information as another parameter. For ease of reading, we’ll refer to these using type synonyms (PGInfo and RedisInfo) from now on:

type PGInfo = ConnectionString
type RedisInfo = ConnectInfo

…

fetchUsersHandler :: PGInfo -> RedisInfo -> Int64 -> Handler User
fetchUsersHandler pgInfo redisInfo uid = do
  ...

The first thing we’ll try is to look up the user by their ID in the Redis cache. If the user exists, we’ll immediately return that user.

fetchUsersHandler :: PGInfo -> RedisInfo -> Int64 -> Handler User
fetchUsersHandler pgInfo redisInfo uid = do
  maybeCachedUser <- liftIO $ fetchUserRedis redisInfo uid
  case maybeCachedUser of
    Just user -> return user
    Nothing -> do
      ...

If the user doesn’t exist, we’ll then drop into the logic of fetching the user in the database. We’ll replicate our logic of throwing an error if we find that user doesn’t actually exist. But if we find the user, we need one more step. Before we return it, we should call cacheUser and store it for the future.

fetchUsersHandler :: PGInfo -> RedisInfo -> Int64 -> Handler User
fetchUsersHandler pgInfo redisInfo uid = do
  maybeCachedUser <- liftIO $ fetchUserRedis redisInfo uid
  case maybeCachedUser of
    Just user -> return user
    Nothing -> do
      maybeUser <- liftIO $ fetchUserPG pgInfo uid
      case maybeUser of
        Just user -> liftIO (cacheUser redisInfo uid user) >> return user
        Nothing -> Handler $ (throwE $ err401 { errBody = "Could not find user with that ID" })

Since we changed our type signature, we’ll have to make a few other updates as well, but these are quite simple:

usersServer :: PGInfo -> RedisInfo -> Server UsersAPI
usersServer pgInfo redisInfo =
  (fetchUsersHandler pgInfo redisInfo) :<|> 
  (createUserHandler pgInfo)


runServer :: IO ()
runServer = do
  pgInfo <- fetchPostgresConnection
  redisInfo <- fetchRedisConnection
  run 8000 (serve usersAPI (usersServer pgInfo redisInfo))

And that’s it! We have a functioning cache with expiring entries. This means that repeated queries to our fetch endpoint should be much faster!
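To make the speedup claim concrete, the sketch below counts database reads in an in-memory model of the cache-aside flow in fetchUsersHandler. IORef Maps stand in for Postgres and Redis. Stores, fetchUserPGModel, and fetchUserCached are hypothetical names, not part of the project:

```haskell
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import qualified Data.Map.Strict as Map

-- In-memory stand-ins for Postgres and Redis, plus a count of database reads.
data Stores = Stores
  { pgTable    :: IORef (Map.Map Int String)
  , redisCache :: IORef (Map.Map Int String)
  , pgReads    :: IORef Int
  }

newStores :: [(Int, String)] -> IO Stores
newStores initialRows =
  Stores <$> newIORef (Map.fromList initialRows) <*> newIORef Map.empty <*> newIORef 0

-- Model of fetchUserPG that also counts how often the "database" is read.
fetchUserPGModel :: Stores -> Int -> IO (Maybe String)
fetchUserPGModel st uid = do
  modifyIORef' (pgReads st) (+ 1)
  Map.lookup uid <$> readIORef (pgTable st)

-- Cache-aside, mirroring fetchUsersHandler: cache first, then the database,
-- caching any database hit. Left plays the role of the err401 response.
fetchUserCached :: Stores -> Int -> IO (Either String String)
fetchUserCached st uid = do
  cached <- Map.lookup uid <$> readIORef (redisCache st)
  case cached of
    Just user -> pure (Right user)
    Nothing -> do
      fromDb <- fetchUserPGModel st uid
      case fromDb of
        Just user -> do
          modifyIORef' (redisCache st) (Map.insert uid user)
          pure (Right user)
        Nothing -> pure (Left "Could not find user with that ID")
```

Fetching user 16 twice reads the database only once, because the second request is served from the cache.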

Conclusion

Caching is a vitally important technique for making software faster for our users. Redis is a key-value store that we can use as a cache for our most frequently used data. We can use it as an alternative to forcing every single API call to hit our database. In Haskell, the Redis API requires everything to be a ByteString. So we have to deal with some logic surrounding encoding and decoding. But otherwise it operates in a very similar way to Persistent and Postgres.

Be sure to take a look at this code on GitHub! There’s a redis branch for this article. It includes all the code samples, including things I skipped over like imports!

We’re starting to get to the point where we’re using a lot of different libraries in our Haskell application! It pays to know how to organize everything, so package management is vital! I tend to use Stack for all my package management. It makes it quite easy to bring all these different libraries together. If you want to learn how to use Stack, check out our free Stack mini-course!

If you’ve never learned Haskell before, you should try it out! Download our Getting Started Checklist!


Serve it up with Servant!

Last week we began our series on production Haskell techniques by learning about Persistent. We created a schema that contained a single User type that we could store in a PostgreSQL database. We examined a couple of functions that let us make SQL queries about these users.

This week, we’ll see how we can expose this database to the outside world using an API. We’ll construct our API using the Servant library. Servant involves some advanced type level constructs, so there’s a lot to wrap your head around. There are definitely simpler approaches to HTTP servers than what Servant uses. But I’ve found that the power Servant gives us is well worth the effort.

This article will give a brief overview of Servant. But if you want a more in-depth introduction, you should check out my talk from BayHac last spring! That talk was more exhaustive about the different combinators you can use in your APIs. It also showed authentication techniques, client functions and documentation. You can also check out the slides and code for that presentation!

Also, take a look at the servant branch on the GitHub repo for this project to see all the code for this article!

Defining our API

The first step in writing an API for our user database is to decide what the different endpoints are. We can decide this independent of what language or library we’ll use. For this article, our API will have two different endpoints. The first will be a POST request to /users. This request will contain a “user” definition in its body, and the result will be that we’ll create a user in our database. Here’s a sample of what this might look like:

POST /users
{
  "userName": "John Doe",
  "userEmail": "john@doe.com",
  "userAge": 29,
  "userOccupation": "Teacher"
}

It will then return a response containing the database key of the user we created. This will allow any clients to fetch the user again. The second endpoint will use the ID to fetch a user by their database identifier. It will be a GET request to /users/:userid. So for instance, the last request might have returned us something like 16. We could then do the following:

GET /users/16

And our response would look like the request body from above.

An API as a Type

So we’ve got our very simple API. How do we actually define this in Haskell, and more specifically with Servant? Well, Servant does something pretty unique (as far as I’ve researched). In Servant we define our API by using a type. Our type will include sub-types for each of the endpoints of our API. We combine the different endpoints by using the (:<|>) operator. I'll sometimes refer to this as “E-plus”, for “endpoint-plus”. This is a type operator, like some of the operators we saw with dependent types and TensorFlow. Here’s the blueprint of our API:

type UsersAPI =
  FetchEndpoint
  :<|> CreateEndpoint

Now let's define what we mean by FetchEndpoint and CreateEndpoint. Endpoints combine different combinators that describe different information about the endpoint. We link combinators together with the (:>) operator, which I call “C-plus” (combinator plus). Here’s what our final API looks like. We’ll go through what each combinator means in the next section:

type UsersAPI =
       "users" :> Capture "userid" Int64 :> Get '[JSON] User
  :<|> "users" :> ReqBody '[JSON] User :> Post '[JSON] Int64
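Servant's real type machinery is much richer, but a hypothetical mini version shows how a library can read information back out of an API type. In the sketch below, the combinators only mimic the shape of Servant's and drop the content types. A type class then walks the type and renders each route. RenderApi and MiniUsersAPI are made-up names:

```haskell
{-# LANGUAGE DataKinds, FlexibleInstances, KindSignatures, PolyKinds,
             ScopedTypeVariables, TypeOperators #-}
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)

-- Hypothetical mini-combinators shaped like Servant's (not Servant's own
-- definitions). They have no values; they only exist as types.
-- Fixities make :> bind tighter than :<|>.
data (path :: k) :> (rest :: Type)
infixr 4 :>
data (a :: Type) :<|> (b :: Type)
infixr 3 :<|>
data Capture (name :: Symbol) (t :: Type)
data ReqBody (t :: Type)
data Get (t :: Type)
data Post (t :: Type)

-- Walk the API type and render a human-readable route for each endpoint.
class RenderApi (api :: Type) where
  renderApi :: Proxy api -> String

-- A type-level string is a static path segment.
instance (KnownSymbol s, RenderApi rest) => RenderApi ((s :: Symbol) :> rest) where
  renderApi _ = "/" ++ symbolVal (Proxy :: Proxy s) ++ renderApi (Proxy :: Proxy rest)

-- A Capture becomes a named placeholder segment.
instance (KnownSymbol n, RenderApi rest) => RenderApi (Capture n t :> rest) where
  renderApi _ = "/:" ++ symbolVal (Proxy :: Proxy n) ++ renderApi (Proxy :: Proxy rest)

-- A request body adds no path segment.
instance RenderApi rest => RenderApi (ReqBody t :> rest) where
  renderApi _ = renderApi (Proxy :: Proxy rest)

instance RenderApi (Get t) where
  renderApi _ = " GET"

instance RenderApi (Post t) where
  renderApi _ = " POST"

instance (RenderApi a, RenderApi b) => RenderApi (a :<|> b) where
  renderApi _ = renderApi (Proxy :: Proxy a) ++ " | " ++ renderApi (Proxy :: Proxy b)

-- Same shape as UsersAPI; Int and String stand in for Int64 and User.
type MiniUsersAPI =
       "users" :> Capture "userid" Int :> Get String
  :<|> "users" :> ReqBody String :> Post Int
```

Here renderApi (Proxy :: Proxy MiniUsersAPI) produces "/users/:userid GET | /users POST". Servant's HasServer and HasClient classes walk the API type in a similar instance-per-combinator way to derive handler and client types.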

Different Combinators

Both of these endpoints have three different combinators. Let’s start by examining the fetch endpoint. It starts off with a string combinator. This is a path component, allowing us to specify what URL extension the caller should use to hit the endpoint. We can use this combinator multiple times to have a more complicated path for the endpoint. If we instead wanted this endpoint to be at /api/users/:userid then we’d change it to:

"api" :> "users" :> Capture "userid" Int64 :> Get '[JSON] User

The second combinator (Capture) allows us to get a value out of the URL itself. We give this value a name and then we supply a type parameter. We won't have to do any path parsing or manipulation ourselves. Servant will handle the tricky business of parsing the URL and passing us an Int64. If you want to capture a value of your own custom type, that's not too difficult. You’ll just have to write an instance of the FromHttpApiData class. All the basic types like Int64 already have instances.
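The real FromHttpApiData class (from the http-api-data package) has methods such as parseUrlPiece :: Text -> Either Text a. The sketch below is a simplified, hypothetical stand-in that uses String to show the idea. Each capturable type knows how to parse one path segment, and a custom newtype can reuse an existing instance. FromPathPiece, parsePiece, and UserId are made-up names:

```haskell
import Data.Int (Int64)
import Text.Read (readMaybe)

-- Simplified, hypothetical stand-in for FromHttpApiData: parse one URL path
-- segment into a typed value, or explain why that failed.
class FromPathPiece a where
  parsePiece :: String -> Either String a

instance FromPathPiece Int64 where
  parsePiece s =
    maybe (Left ("could not parse " ++ show s ++ " as Int64")) Right (readMaybe s)

-- A hypothetical custom type that reuses the Int64 instance.
newtype UserId = UserId Int64
  deriving (Eq, Show)

instance FromPathPiece UserId where
  parsePiece s = UserId <$> parsePiece s
```

With this sketch, parsePiece "16" gives Right (UserId 16). Parsing "abc" at Int64 gives a Left with a readable message.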

The final combinator itself contains three important pieces of information for this endpoint. First, it tells us that this is in fact a GET request. Second, it gives us the list of content-types that are allowable in the response. This is a type level list of content formats. Our data type needs serialization and deserialization instances for each content type in this list. We could have used a more complicated list like '[JSON, PlainText, OctetStream]. But for the rest of this article, we’ll just use JSON. This means we'll use the ToJSON and FromJSON typeclasses for serialization.

The last piece of this combinator is the type our endpoint returns. So a successful request will give the caller back a response that contains a User in JSON format. Notice this isn’t a Maybe User. If the ID is not in our database, we’ll return a 401 error to indicate failure, rather than returning Nothing.

Our second endpoint has many similarities. It uses the same string path component. Then its final combinator is the same except that it indicates it is a POST request instead of a GET request. The second combinator then tells us what we can expect the request body to look like. In this case, the request body should contain a JSON representation of a User. It requires a list of acceptable content types, and then the type we want, like the Get and Post combinators.

That completes the “definition” of our API. We’ll need to add ToJSON and FromJSON instances of our User type in order for this to function. You can take a look at those on GitHub, and check out this article for more details on creating those instances!

Writing Handlers

Now that we’ve defined the type of our API, we need to write handler functions for each endpoint. This is where Servant’s awesomeness kicks in. We can map each endpoint up to a function that has a particular type based on the combinators in the endpoint. So, first let’s remember our endpoint for fetching a user:

"users" :> Capture "userid" Int64 :> Get '[JSON] User

The string path component doesn’t add any arguments to our function. The Capture component will result in a parameter of type Int64 that we’ll need in our function. Then the return type of our function should be User. This almost completely defines the type signature of our handler. We'll note though that it needs to be in the Handler monad. So here’s what it’ll look like:

fetchUsersHandler :: Int64 -> Handler User
...

Now let’s apply the same reasoning to the type of our create endpoint:

"users" :> ReqBody '[JSON] User :> Post '[JSON] Int64

A ReqBody combinator adds one parameter to the handler, and that parameter's type is the combinator's type argument (here, User). So the endpoint resolves into the handler type:

createUserHandler :: User -> Handler Int64
...

Now, we’ll need to be able to access our Postgres database through both these handlers. So they’ll each get an extra parameter referring to the ConnectionString. We’ll pass that from our code so that by the time Servant is resolving the types, the parameter is accounted for:

fetchUsersHandler :: ConnectionString -> Int64 -> Handler User
createUserHandler :: ConnectionString -> User -> Handler Int64

Before we go any further, we should discuss the Handler monad. This is a newtype wrapper around the monad ExceptT ServantErr IO. In other words, each of these requests might fail, and to make one fail, we throw an error of type ServantErr. Since IO sits at the base of this stack, we can also run IO actions by lifting them with liftIO, which we need for network and database operations.
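To make the shape of Handler concrete, here is a toy rebuilt from the transformers package. AppErr, lookupName, and fetchName are hypothetical stand-ins, not Servant's types. The structure still mirrors the description above: a newtype over ExceptT, IO work lifted with liftIO, and failure signalled with throwE.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE)

-- Hypothetical stand-in for ServantErr: a status code and a message.
data AppErr = AppErr { errCode :: Int, errMessage :: String }
  deriving (Show, Eq)

-- Same shape as Servant's Handler: a newtype over ExceptT err IO.
newtype Handler a = Handler { runHandler :: ExceptT AppErr IO a }
  deriving (Functor, Applicative, Monad, MonadIO)

-- Hypothetical IO lookup standing in for a database query.
lookupName :: Int -> IO (Maybe String)
lookupName 1 = return (Just "admin")
lookupName _ = return Nothing

-- IO work is lifted with liftIO; a missing row short-circuits via throwE.
fetchName :: Int -> Handler String
fetchName uid = do
  maybeName <- liftIO (lookupName uid)
  case maybeName of
    Just name -> return name
    Nothing -> Handler (throwE (AppErr 401 ("Could not find user with ID: " ++ show uid)))
```

Running runExceptT (runHandler (fetchName 2)) yields a Left value that carries the error. Servant does something similar when it turns a thrown ServantErr into an HTTP response.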

Before we implement these functions, let’s first write a couple simple helpers. These will use the runAction function from last week’s article to run database actions:

fetchUserPG :: ConnectionString -> Int64 -> IO (Maybe User)
fetchUserPG connString uid = runAction connString (get (toSqlKey uid))

createUserPG :: ConnectionString -> User -> IO Int64
createUserPG connString user = fromSqlKey <$> runAction connString (insert user)

For completeness (and use later in testing), we’ll also add a simple delete function. We need the where clause for type inference:

deleteUserPG :: ConnectionString -> Int64 -> IO ()
deleteUserPG connString uid = runAction connString (delete userKey)
  where
    userKey :: Key User
    userKey = toSqlKey uid
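The need for the annotation can be reproduced in a base-only toy. Key, Table, toKey, and deleteSql are hypothetical stand-ins for Persistent's machinery. When the record type appears only inside the key, nothing else tells GHC which table the key belongs to.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Proxy (Proxy (..))

-- Hypothetical toy version of a phantom-typed key, like Persistent's Key User.
newtype Key record = Key Int

-- Each table type knows its own table name.
class Table record where
  tableName :: Proxy record -> String

data User

instance Table User where
  tableName _ = "users"

-- Like toSqlKey: the result type says nothing about which table it is for.
toKey :: Int -> Key record
toKey = Key

-- The record type appears only inside the key, so the key's type must fix it.
deleteSql :: forall record. Table record => Key record -> String
deleteSql (Key n) =
  "DELETE FROM " ++ tableName (Proxy :: Proxy record) ++ " WHERE id = " ++ show n

-- Without the annotation on userKey, `deleteSql (toKey uid)` would be ambiguous.
deleteUserSql :: Int -> String
deleteUserSql uid = deleteSql userKey
  where
    userKey :: Key User
    userKey = toKey uid
```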

Now from our Servant handlers, we’ll call these two functions. This will completely cover the case of the create endpoint. But we’ll need a little bit more logic for the fetch endpoint. Since our functions are in the IO monad, we have to lift them up to Handler.

fetchUsersHandler :: ConnectionString -> Int64 -> Handler User
fetchUsersHandler connString uid = do
  maybeUser <- liftIO $ fetchUserPG connString uid
  ...

createUserHandler :: ConnectionString -> User -> Handler Int64
createUserHandler connString user = liftIO $ createUserPG connString user

To complete our fetch handler, we need to account for a non-existent user. Instead of making the type of the whole endpoint a Maybe, we’ll throw a ServantErr in this case. We can use one of the built-in Servant error values, which correspond to normal error codes, and then customize the error body. In this case, we’ll throw a 401 error. Here’s how we do that:

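-- Needs: import Control.Monad.IO.Class (liftIO)
--        import Control.Monad.Trans.Except (throwE)
--        import qualified Data.ByteString.Lazy.Char8 as BSL
fetchUsersHandler :: ConnectionString -> Int64 -> Handler User
fetchUsersHandler connString uid = do
  maybeUser <- liftIO $ fetchUserPG connString uid
  case maybeUser of
    Just user -> return user
    -- errBody is a lazy ByteString, so we pack the String message
    Nothing -> Handler $ throwE $ err401 { errBody = BSL.pack ("Could not find user with ID: " ++ show uid) }

createUserHandler :: ConnectionString -> User -> Handler Int64
createUserHandler connString user = liftIO $ createUserPG connString user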

And that’s it! We're done with our handler functions!

Combining it All into a Server

Our next step is to create an object of type Server over our API. This is actually remarkably simple. When we defined the original type, we combined the endpoints with the (:<|>) operator. To make our Server, we do the same thing but with the handler functions:

usersServer :: ConnectionString -> Server UsersAPI
usersServer pgInfo = 
  (fetchUsersHandler pgInfo) :<|> 
  (createUserHandler pgInfo)
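Here is a rough model of how these types line up. It is a base-only toy, not Servant's implementation. The combinators and the HandlerOf type family are hypothetical, but they compute a handler type from an API type the way this article describes: path strings add nothing, Capture and ReqBody add parameters, and :<|> pairs up handlers in the same order as the endpoints.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

import Data.Int (Int64)
import Data.Kind (Type)
import GHC.TypeLits (Symbol)

-- Toy combinators modelled on Servant's (hypothetical, not Servant's definitions).
data (path :: k) :> (rest :: Type)
infixr 4 :>
data a :<|> b = a :<|> b
infixr 3 :<|>
data Capture (name :: Symbol) (a :: Type)
data ReqBody (a :: Type)
data Get (a :: Type)
data Post (a :: Type)

-- Compute the handler type for an API type, one combinator at a time.
type family HandlerOf (api :: Type) :: Type where
  HandlerOf (a :<|> b) = HandlerOf a :<|> HandlerOf b
  HandlerOf ((path :: Symbol) :> rest) = HandlerOf rest  -- path segments add nothing
  HandlerOf (Capture name a :> rest) = a -> HandlerOf rest
  HandlerOf (ReqBody a :> rest) = a -> HandlerOf rest
  HandlerOf (Get a) = IO a
  HandlerOf (Post a) = IO a

type UsersAPI =
       "users" :> Capture "userid" Int64 :> Get String
  :<|> "users" :> ReqBody String :> Post Int64

-- The compiler checks these handlers against the computed type
-- (Int64 -> IO String) :<|> (String -> IO Int64).
usersServer :: HandlerOf UsersAPI
usersServer = fetchUser :<|> createUser
  where
    fetchUser uid = return ("user #" ++ show uid)
    createUser name = return (fromIntegral (length name))
```

Reordering the two handlers in usersServer makes it fail to type-check. Servant gives us the same kind of guarantee for the real server.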

And Servant does all the work of ensuring that the type of each endpoint matches up with the type of the handler! It’s pretty awesome. Suppose we changed the type of our fetchUsersHandler so that it took a Key User instead of an Int64. We’d get a compile error:

fetchUsersHandler :: ConnectionString -> Key User -> Handler User
…

-- Compile Error!
• Couldn't match type ‘Key User’ with ‘Int64’
      Expected type: Server UsersAPI
        Actual type: (Key User -> Handler User)
                     :<|> (User -> Handler Int64)

There's now a mismatch between our API definition and our handler definition, so the compiler rejects our code! The one issue is that the error messages can be rather difficult to interpret sometimes. This is especially the case when your API becomes very large, since the “Actual type” section of the above error will become massive! So always be careful when changing your endpoints! Frequent compilation is your friend!

Building the Application

The final piece of the puzzle is to actually build an Application object out of our server. The first step of this process is to create a Proxy for our API. Remember that our API is a type, and not a term. But a Proxy allows us to represent this type at the term level. The concept is a little complicated, but the code is not!

import Data.Proxy

…

usersAPI :: Proxy UsersAPI
usersAPI = Proxy :: Proxy UsersAPI
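You can see the basic mechanism with GHC.TypeLits from base. A Proxy lets an ordinary function read information that exists only in a type. Servant relies on the same idea, for instance through KnownSymbol for path segments. RoutePrefix and Port are hypothetical synonyms used only for illustration.

```haskell
{-# LANGUAGE DataKinds #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (natVal, symbolVal)

-- Type-level values with no term-level representation of their own.
type RoutePrefix = "users"
type Port = 8000

-- A Proxy carries each type to the term level, where these functions read it back.
routePrefix :: String
routePrefix = symbolVal (Proxy :: Proxy RoutePrefix)

port :: Integer
port = natVal (Proxy :: Proxy Port)
```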

Now we can make our runnable Application like so (assuming we have a Postgres connection):

serve usersAPI (usersServer connString)

We’ll run this server on port 8000 by using the run function from Network.Wai.Handler.Warp. (See Github for a full list of imports.) We’ll fetch our connection string, and then we’re good to go!

runServer :: IO ()
runServer = do
  pgInfo <- fetchPostgresConnection
  run 8000 (serve usersAPI (usersServer pgInfo))

Conclusion

The Servant library offers some truly awesome possibilities. We’re able to define a web API at the type level. We can then define handler functions using the parameters the endpoints expect. Servant handles all the work of marshalling back and forth between the HTTP request and the native Haskell types. It also ensures a match between the endpoints and the handler function types!

If you want to see even more of the possibilities that Servant offers, you should watch my talk from BayHac. It goes through some more advanced concepts like authentication and client-side functions. You can get the slides and all the code examples for that talk here.

If you’ve never tried Haskell before, there’s no time like the present to start! Download our Getting Started Checklist for some tools to help start your Haskell journey!

James Bowen

Trouble with Databases? Persevere with Persistent!

Our recent series at Monday Morning Haskell focused on machine learning. In particular, we did a deep dive on the Haskell Tensor Flow library. While AI is a huge area indeed, it doesn't account for the bulk of day-to-day work. To build even a basic production system, there is a multitude of simpler tasks. In our newest series, we’ll be learning a host of different libraries to perform these tasks!

In this first article we’ll discuss Persistent. Many libraries allow you to make a quick SQL call. But Persistent does much more than that. With Persistent, you can link your Haskell types to your database definition. You can also make type-safe queries to save yourself the hassle of decoding data. All in all, it's a very cool system.

All the code for this series will be on Github! To follow along with this article, take a look at the persistent branch.

Our Basic Type

We’ll start by considering a simple user type that looks like this:

data User = User
  { userName :: Text
  , userEmail :: Text
  , userAge :: Int
  , userOccupation :: Text
  }

Imagine we want to store objects of this type in an SQL database. We’ll first need to define the table to store our users. We could do this with a manual SQL command or through an editor, but regardless, the process will be error prone. The command would look something like this:

create table users (
  name varchar(100),
  email varchar(100),
  age bigint,
  occupation varchar(100)
)

When we do this, there's nothing linking our Haskell data type to the table structure. If we update the Haskell code, we have to remember to update the database. And this means writing another error-prone command.

From our Haskell program, we’ll also want to make SQL queries based on the structure of the user. We could write out these raw commands and execute them, but the same issues apply. This method is error prone and not at all type-safe. Persistent helps us solve these problems.

Persistent and Template Haskell

We can get these benefits from Persistent without all that much extra code! To do this, we’re going to use Template Haskell (TH). We’ve seen TH once in the past when we were deriving lenses and prisms for different data types. On that occasion we noted a few pros and cons of TH. It does allow us to avoid writing some boilerplate code. But it will make our compile times longer as well. It will also make our code less accessible to inexperienced Haskellers. With lenses though, it only saved us a few dozen lines of total code. With Persistent, TH generates a lot more code, so the pros definitely outweigh the cons.

When we created lenses with TH, we used a simple declaration makeLenses. Here, we’ll do something a bit more complicated. We’ll use a language construct called a “quasi-quoter”. This is a block of code that follows some syntax designed by the programmer or in a library, rather than normal Haskell syntax. It is often used in libraries that do some sort of foreign function interface. We delimit a quasi-quoter by a combination of brackets and pipes. Here’s what the Template Haskell call looks like. The quasi-quoter is the final argument:

import qualified Database.Persist.TH as PTH

PTH.share [PTH.mkPersist PTH.sqlSettings, PTH.mkMigrate "migrateAll"] [PTH.persistLowerCase|

|]

The share function takes a list of settings and then the quasi-quoter itself. It then generates the necessary Template Haskell for our data schema. Within this section, we’ll define all the different types our database will use. We also specify certain settings about how those types should be handled. In particular we specify sqlSettings, so everything we do here will focus on an SQL database. More importantly, we also create a migration function, migrateAll. After this Template Haskell gets compiled, this function will allow us to migrate our DB. This means it will create all our tables for us!

But before we see this in action, we need to re-define our user type. Instead of defining User in the normal Haskell way, we’re going to define it within the quasi-quoter. Note that this level of Template Haskell requires many compiler extensions. Here’s our definition:

{-# LANGUAGE TemplateHaskell            #-}
{-# LANGUAGE QuasiQuotes                #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE OverloadedStrings          #-}



PTH.share [PTH.mkPersist PTH.sqlSettings, PTH.mkMigrate "migrateAll"] [PTH.persistLowerCase|
  User sql=users
    name Text
    email Text
    age Int
    occupation Text
    UniqueEmail email
    deriving Show Read
|]

There are a lot of similarities to a normal data definition in Haskell. We’ve changed the formatting and reversed the order of the types and names. But you can still tell what’s going on. The field names are all there. We’ve still derived basic instances like we would in Haskell.

But we’ve also added some new directives. For instance, we’ve stated what the table name should be (by default it would be user, not users). We’ve also created a UniqueEmail constraint. This tells our database that each user has to have a unique email. The migration will handle creating all the necessary indices for this to work!

This Template Haskell will generate the normal Haskell data type for us. All fields will have the prefix user and will be camel-cased (userName, userEmail, and so on). The compiler will also generate certain special instances for our type. These will enable us to use Persistent's type-safe query functions. Finally, this code generates field identifiers such as UserAge and UserOccupation (constructors of Persistent's EntityField type, which we'll loosely call lenses) that we'll use as filters in our queries, as we'll see later.

Entities and Keys

Persistent also has a construct allowing us to handle database IDs. For each type we put in the schema, we’ll have a corresponding Entity type. An Entity refers to a row in our database, and it associates a database ID with the object itself. The database ID has the type Key User. For an SQL backend it wraps an Int64, and we can convert between the two with toSqlKey and fromSqlKey. So the following would look like a valid entity:

import Database.Persist (Entity(..))
import Database.Persist.Sql (toSqlKey)

sampleUser :: Entity User
sampleUser = Entity (toSqlKey 1) $ User
  { userName = "admin"
  , userEmail = "admin@test.com"
  , userAge = 23
  , userOccupation = "System Administrator"
  }

This is a nice little abstraction that allows us to avoid muddling our user type with the database ID. It lets the rest of our code use a purer User type.

The SqlPersistT Monad

So now that we have the basics of our schema, how do we actually interact with our database from Haskell code? As a specific example, we’ll be accessing a PostgreSQL database. This requires the SqlPersistT monad. All the query functions return actions in this monad. The monad transformer has to live on top of a monad that is MonadIO, since we obviously need IO to run database queries.

If we’re trying to make a database query from a normal IO function, the first thing we need is a ConnectionString. This string encodes information about the location of the database. The connection string generally has 4-5 components: the host/IP address, the port, the database username, the database name, and optionally a password. So for instance if you’re running Postgres on your local machine, you might have something like:

{-# LANGUAGE OverloadedStrings #-}

import Database.Persist.Postgresql (ConnectionString)

connString :: ConnectionString
connString = "host=127.0.0.1 port=5432 user=postgres dbname=postgres password=password"
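ConnectionString is a ByteString synonym, so one option is to assemble the string from a settings record and convert it with Data.ByteString.Char8.pack. The sketch below uses a hypothetical PgSettings record, only builds the String, and treats the password as optional.

```haskell
-- Hypothetical settings record; the field names are illustrative.
data PgSettings = PgSettings
  { pgHost     :: String
  , pgPort     :: Int
  , pgUser     :: String
  , pgDatabase :: String
  , pgPassword :: Maybe String
  }

-- Render libpq-style "key=value" pairs separated by spaces.
-- Values containing spaces or quotes would need quoting, which this sketch skips.
renderConnString :: PgSettings -> String
renderConnString s = unwords (required ++ optional)
  where
    required =
      [ "host=" ++ pgHost s
      , "port=" ++ show (pgPort s)
      , "user=" ++ pgUser s
      , "dbname=" ++ pgDatabase s
      ]
    optional = maybe [] (\p -> ["password=" ++ p]) (pgPassword s)

-- The local settings from the example above.
localSettings :: PgSettings
localSettings = PgSettings "127.0.0.1" 5432 "postgres" "postgres" (Just "password")
```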

Now that we have the connection string, we’re set to call withPostgresqlConn. This function takes the string and then a function requiring a backend:

-- Also various constraints on the monad m
withPostgresqlConn :: (IsSqlBackend backend) => ConnectionString -> (backend -> m a) -> m a

The IsSqlBackend constraint forces us to use a type that conforms to Persistent’s guidelines. The SqlPersistT monad is only a synonym for ReaderT SqlBackend. So in general, the only thing we’ll do with this backend is use it as an argument to runReaderT. Wrapping this pattern in a helper function lets us run any action within SqlPersistT:

import Control.Monad.Logger (LoggingT, runStdoutLoggingT)
import Control.Monad.Reader (runReaderT)
import Database.Persist.Postgresql (ConnectionString, withPostgresqlConn, SqlPersistT)

…

-- SqlPersistT needs its base monad spelled out: here LoggingT IO.
runAction :: ConnectionString -> SqlPersistT (LoggingT IO) a -> IO a
runAction connectionString action = runStdoutLoggingT $ withPostgresqlConn connectionString $ \backend ->
  runReaderT action backend

Note we add in a call to runStdoutLoggingT so that our action can log its results, as Persistent expects. This is necessary whenever we use withPostgresqlConn. Here's how we would run our migration function:

migrateDB :: IO ()
migrateDB = runAction connString (runMigration migrateAll)

This will create the users table, perfectly to spec with our data definition!
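To see why runReaderT is all that runAction needs, here is a toy with the same shape, built on the transformers package. Backend, FakePersistT, insertRow, countRows, and runFake are hypothetical, and an IORef stands in for a real database connection.

```haskell
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT)
import Data.IORef (IORef, atomicModifyIORef', newIORef, readIORef)

-- Hypothetical in-memory "backend": just a mutable list of rows.
newtype Backend = Backend { backendRows :: IORef [String] }

-- Same shape as SqlPersistT: a ReaderT that hands every action the backend.
type FakePersistT m = ReaderT Backend m

-- Append a row and return its 1-based position as a fake key.
insertRow :: String -> FakePersistT IO Int
insertRow row = do
  ref <- asks backendRows
  liftIO (atomicModifyIORef' ref (\rows -> (rows ++ [row], length rows + 1)))

countRows :: FakePersistT IO Int
countRows = do
  ref <- asks backendRows
  liftIO (length <$> readIORef ref)

-- Like runAction: create the backend, then supply it with runReaderT.
runFake :: FakePersistT IO a -> IO a
runFake action = do
  ref <- newIORef []
  runReaderT action (Backend ref)
```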

Queries

Now let’s wrap up by taking a quick examination of the kinds of queries we can run. The first thing we could do is insert a new user into our database. For this, Persistent has the insert function. When we insert the user, we’ll get a key for that user as a result. Here’s the type signature for insert specified to our particular User type:

insert :: (MonadIO m) => User -> SqlPersistT m (Key User)

Then of course we can also do things in reverse. Suppose we have a key for our user and we want to get the user out of the database. We’ll want the get function. Of course this might fail if there is no corresponding user in the database, so we need a Maybe:

get :: (MonadIO m) => Key User -> SqlPersistT m (Maybe User)

We can use these functions for any type satisfying the PersistRecordBackend constraint. This comes for free when we use the Template Haskell approach. So you can use these queries on any type that lives in your schema.

But SQL allows us to do much more than query with the key. Suppose we want to get all the users that meet certain criteria. We’ll want to use the selectList function, which replicates the behavior of the SQL SELECT command. It takes a couple different arguments for the different ways to run a selection. The two list types look a little complicated, but we’ll examine them in more detail:

selectList
  :: (MonadIO m, PersistRecordBackend val SqlBackend)
  => [Filter val]
  -> [SelectOpt val]
  -> SqlPersistT m [Entity val]

As before, the PersistRecordBackend constraint is satisfied by any type in our TH schema. So we know our User type fits. So let’s examine the first argument. It provides a list of different filters that will determine which elements we fetch. For instance, suppose we want all users who are younger than 25 and whose occupation is “Teacher”. Remember the lenses I mentioned that get generated? We’ll create two different filters on this by using these lenses.

selectYoungTeachers :: (MonadIO m) => SqlPersistT m [Entity User]
selectYoungTeachers = selectList [UserAge <. 25, UserOccupation ==. "Teacher"] []

We use the UserAge lens and the UserOccupation lens to choose the fields to filter on. We use a "less-than" operator to state that the age must be smaller than 25. Similarly, we use the ==. operator to match on the occupation. Then we provide an empty list of SelectOpts.

The second list of selection operations provides some other features we might expect in a select statement. First, we can provide an ordering on our returned data. We’ll also use the generated lenses here. For instance, Asc UserEmail will order our list by email. Here’s what an ordered query would look like:

selectYoungTeachers' :: (MonadIO m) => SqlPersistT m [Entity User]
selectYoungTeachers' = selectList [UserAge <. 25, UserOccupation ==. "Teacher"] [Asc UserEmail]

The other types of SelectOpts include limits and offsets. For instance, we can further modify this query to exclude the first 5 users (as ordered by email) and then limit our selection to 100:

selectYoungTeachers' :: (MonadIO m) => SqlPersistT m [Entity User]
selectYoungTeachers' = selectList
  [UserAge <. 25, UserOccupation ==. "Teacher"] [Asc UserEmail, OffsetBy 5, LimitTo 100]
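To check our reading of these options, here is a plain-list analogue of the last query. Persistent does not run queries this way (the real work happens in SQL), and the String fields simplify the schema's Text fields. It does show the intended semantics: the filters are combined with AND, the ordering is applied first, and the offset is applied before the limit.

```haskell
import Data.List (sortOn)

data User = User
  { userName       :: String
  , userEmail      :: String
  , userAge        :: Int
  , userOccupation :: String
  } deriving (Show, Eq)

-- In-memory analogue of
--   selectList [UserAge <. 25, UserOccupation ==. "Teacher"]
--              [Asc UserEmail, OffsetBy 5, LimitTo 100]
-- Both filters must hold; sort by email, skip 5, then keep at most 100.
youngTeachers :: [User] -> [User]
youngTeachers = take 100 . drop 5 . sortOn userEmail . filter matches
  where
    matches u = userAge u < 25 && userOccupation u == "Teacher"
```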

And that’s all there is to making queries that are type-safe and sensible. We know we’re actually filtering on values that make sense for our types. We don’t have to worry about typos ruining our code at runtime.

Conclusion

Persistent gives us some excellent tools for interacting with databases from Haskell. The Template Haskell mechanisms generate a lot of boilerplate code that helps us. For instance, we can migrate our database to create the correct tables for our Haskell types. We also can perform queries that filter results in a type-safe way. All in all, it’s a fantastic experience.

Never tried Haskell? Do other languages frustrate you with runtime errors whenever you try to run SQL queries? You should give Haskell a try! Download our Getting Started Checklist for some tools that will help you!

If you’re a little familiar with Haskell but aren’t sure how to incorporate libraries like Persistent, you should check out our Stack mini-course. It’ll walk you through the basics of making a simple Haskell program using Stack.

James Bowen

Grenade! Dependently Typed Neural Networks

In the last couple weeks we explored one of the most complex topics I’ve presented on this blog. We examined potential runtime failures that can occur when using Tensor Flow. These included mismatched dimensions and missing placeholders. In an ideal world, we would catch these issues at compile time instead. At its current stage, the Haskell Tensor Flow library doesn’t support that. But we demonstrated that it is possible to add a layer to do this by using dependent types.

Now, I’m still very much a novice at dependent types, so the solutions I presented were rather clunky. This week I'll show a better example of this concept from a different library. The Grenade library uses dependent types everywhere. It allows us to build verifiably-valid neural networks with extreme concision. So let’s dive in and see what it’s all about!

Shapes and Layers

The first things to learn with this library are the two concepts of Shapes and Layers. Shapes are best compared to tensors from Tensor Flow, except that they exist at the type level. In Tensor Flow we could build tensors with arbitrary dimensions. Grenade currently only supports up to three dimensions. So the different shape types start with D1, D2, or D3, depending on the dimensionality of the shape. Then each of these type constructors takes a set of natural number parameters. So the following are all valid “Shape” types within Grenade:

D1 5
D2 4 12
D3 8 10 2

The first represents a vector with 5 elements. The second represents a matrix with 4 rows and 12 columns. And the third represents an 8x10x2 matrix (or tensor, if you like). The different numbers represent those values at the type level, not the term level. If this seems confusing, here’s a good tutorial that goes into more depth about the basics of dependent types. The most important idea is that something of type D1 5 can only have 5 elements. A vector of 4 or 6 elements will not type-check.
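A toy length-indexed vector can be written with GHC.TypeLits from base. Vec and mkVec are hypothetical and much simpler than Grenade's shapes. They still show the key idea: the length lives in the type, and the only way to build a value checks the data against that length.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Toy vector whose length is part of its type, in the spirit of D1 n.
-- In a real module the constructor would stay unexported, so mkVec is the only way in.
newtype Vec (n :: Nat) = Vec [Double]
  deriving (Show, Eq)

-- Succeeds only when the list length matches the type-level n.
mkVec :: forall n. KnownNat n => [Double] -> Maybe (Vec n)
mkVec xs
  | length xs == fromIntegral (natVal (Proxy :: Proxy n)) = Just (Vec xs)
  | otherwise = Nothing
```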

So now that we know about shapes, let’s examine layers. Layers describe relationships between our shapes. They encapsulate the transformations that happen on our data. The following are all valid layer types:

Relu
FullyConnected 10 20
Convolution 1 10 5 5 1 1

The layer Relu describes a layer that takes in data of any kind of shape and outputs the same shape. In between, it applies the relu activation function to the input data. Since it doesn’t change the shape, it doesn’t need any parameters.

A FullyConnected layer represents the canonical layer of a neural network. It has two parameters, one for the number of input neurons and one for the number of output neurons. In this case, the layer will take 10 inputs and produce 20 outputs.

A Convolution layer represents a 2D convolution like we saw with our MNIST network. This particular example has 1 input feature, 10 output features, uses a 5x5 patch size, and a 1x1 patch offset.

Describing a Network

Now that we have a basic grasp on shapes and layers, we can see how they fit together to create a full network. A network type has two type parameters. The first parameter is a list of the layers representing the transformations on the data. The second parameter is a list of the shapes that our data takes at any given point throughout the network. So let’s say we wanted to describe a very simple network. It will take 4 inputs and produce 10 outputs using a fully connected layer. Then it will perform a Relu activation. This network looks like this:

type SimpleNetwork = Network
  '[FullyConnected 4 10, Relu]
  '[ 'D1 4, 'D1 10, 'D1 10]

The apostrophes in front of the lists and D1 terms indicate that these are promoted constructors. So they are types instead of terms. To “read” this type, we start with the first data format. We go to each successive data format by applying the transformation layer. So for instance we start with a 4-vector, and transform it into a 10-vector with a fully-connected layer. Then we transform that 10-vector into another 10-vector by applying relu. That’s all there is to it! We could apply another FullyConnected layer onto this that will have 3 outputs like so:

type SimpleNetwork = Network
  '[FullyConnected 4 10, Relu, FullyConnected 10 3]
  '[ 'D1 4, 'D1 10, 'D1 10, 'D1 3]
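Here is a simplified toy of this idea, not Grenade's implementation. Grenade indexes Network by type-level lists of layers and shapes, while the hypothetical Net GADT below only tracks input and output sizes. It still shows the essential rule: each layer's output size must match the next layer's input size.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}

import GHC.TypeLits (Nat)

-- Toy layers indexed by input and output size (no weights, just structure).
data Layer (i :: Nat) (o :: Nat) where
  FullyConnected :: Layer i o  -- sizes come from a type annotation
  Relu           :: Layer n n  -- shape-preserving

-- A chain of layers from size i to size o; adjacent sizes must agree.
data Net (i :: Nat) (o :: Nat) where
  Output :: Net n n
  (:~>)  :: Layer i h -> Net h o -> Net i o
infixr 5 :~>

-- Roughly the SimpleNetwork above: 4 -> 10 -> 10 -> 3.
simpleNet :: Net 4 3
simpleNet =
  (FullyConnected :: Layer 4 10)
    :~> Relu
    :~> (FullyConnected :: Layer 10 3)
    :~> Output

-- Number of layers in the chain.
depth :: Net i o -> Int
depth Output = 0
depth (_ :~> rest) = 1 + depth rest
```

If you change the second annotation to Layer 20 3, the compiler rejects the program, because the first layer produces size 10 while the rest of the chain now expects 20.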

Let's look at MNIST to see a more complicated example. We'll start with a 28x28 image of data. Then we’ll perform the convolution layer I mentioned above. This gives us a 3-dimensional tensor of size 24x24x10. Then we can perform 2x2 max pooling on this, resulting in a 12x12x10 tensor. Finally, we can apply a Relu layer, which keeps it at the same size:

type MNISTStart = Network
  '[Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu]
  '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10]

Here’s what a full MNIST example might look like (per the README on the library’s Github page):

type MNIST = Network
    '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu
     , Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu
     , FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
    '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10
     , 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256
     , 'D1 80, 'D1 80, 'D1 10, 'D1 10]
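The shapes in this type follow from simple arithmetic. Assume no padding, which matches the sizes written above. Then a window of size k moved with stride s over n positions produces (n - k) / s + 1 outputs, rounded down. The helper below is a hypothetical check of the MNIST shapes and is not part of Grenade.

```haskell
-- Output length along one axis for a window (convolution kernel or pooling
-- region) slid with the given stride and no padding.
windowOut :: Int -> Int -> Int -> Int
windowOut inSize window stride = (inSize - window) `div` stride + 1

-- Trace the spatial sizes of the MNIST type above.
mnistSizes :: [Int]
mnistSizes = [c1, p1, c2, p2, flat]
  where
    c1 = windowOut 28 5 1   -- Convolution 1 10 5 5 1 1
    p1 = windowOut c1 2 2   -- Pooling 2 2 2 2
    c2 = windowOut p1 5 1   -- Convolution 10 16 5 5 1 1
    p2 = windowOut c2 2 2   -- Pooling 2 2 2 2
    flat = p2 * p2 * 16     -- FlattenLayer over 16 channels
```

This reproduces the 24, 12, 8, 4, and 256 that appear in the shape list.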

This is a much simpler and more concise description of our network than we can get in Tensor Flow! Let’s examine the ways in which the library uses dependent types to its advantage.

The Magic of Dependent Types

Describing our network as a type seems like a strange idea if you’ve never used dependent types before. But it gives us a couple great perks!

The first major win we get is that it is very easy to generate the starting values of our network. Since it has a specific type, we can let type inference guide us! We don’t need any term level code that is specific to the shape of our network. All we need to do is attach the type signature and call randomNetwork!

randomSimple :: MonadRandom m => m SimpleNetwork
randomSimple = randomNetwork

This will give us all the initial values we need, so we can get going!
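As a toy illustration of initialisation driven by types alone, here is a base-only sketch. Weights and initWeights are hypothetical, and they fill the matrix with a constant instead of random values. Like randomNetwork, though, they read the dimensions from the type, so no term-level size arguments are needed.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Toy weight matrix for a layer from i inputs to o outputs (hypothetical).
newtype Weights (i :: Nat) (o :: Nat) = Weights [[Double]]

-- Build weights using only the type's dimensions, in the spirit of randomNetwork.
-- A constant fill stands in for random sampling to keep the sketch pure.
initWeights :: forall i o. (KnownNat i, KnownNat o) => Weights i o
initWeights = Weights (replicate rows (replicate cols 0.1))
  where
    rows = fromIntegral (natVal (Proxy :: Proxy o))
    cols = fromIntegral (natVal (Proxy :: Proxy i))

-- (rows, columns) of the stored matrix.
dims :: Weights i o -> (Int, Int)
dims (Weights rs) = case rs of
  []      -> (0, 0)
  (r : _) -> (length rs, length r)
```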

The second (and more important) win is that we can’t build an invalid network! Suppose we try to take our simple network and somehow format it incorrectly. For instance, we could say that instead of the input shape being of size 4, it’s of size 7:

type SimpleNetwork = Network
  '[FullyConnected 4 10, Relu, FullyConnected 10 3]
  '[ 'D1 7, 'D1 10, 'D1 10, 'D1 3]
-- ^^ Notice this 7

This will result in a compile error, since there is a mismatch between the layers. The first layer expects an input of 4, but the first data format is of length 7!

Could not deduce (Layer (FullyConnected 4 10) ('D1 7) ('D1 10))
        arising from a use of ‘randomNetwork’
      from the context: MonadRandom m
        bound by the type signature for:
                   randomSimple :: MonadRandom m => m SimpleNetwork
        at src/IrisGrenade.hs:29:1-48

In other words, it notices that the chain from D1 7 to D1 10 using a FullyConnected 4 10 layer is invalid. So it doesn’t let us make this network. The same thing would happen if we made the layers themselves invalid. For instance, we could make the output and input of the two fully-connected layers not match up:

-- We changed the second to take 20 as the number of input elements.
type SimpleNetwork = Network 
  '[FullyConnected 4 10, Relu, FullyConnected 20 3]
  '[ 'D1 4, 'D1 10, 'D1 20, 'D1 3]

…

/Users/jamesbowen/HTensor/src/IrisGrenade.hs:30:16: error:
    • Could not deduce (Layer (FullyConnected 20 3) ('D1 10) ('D1 3))
        arising from a use of ‘randomNetwork’
      from the context: MonadRandom m
        bound by the type signature for:
                   randomSimple :: MonadRandom m => m SimpleNetwork
        at src/IrisGrenade.hs:29:1-48

So Grenade makes our program much safer by providing compile time guarantees about our network's validity. Mismatched dimensions between the layers and shapes of a network become compile errors instead of runtime errors!

Training the Network on Iris

Now let’s do a quick run-through of how we actually train this neural network. Readers with a keen eye may have noticed that the SimpleNetwork we’ve built is the same network we used to train the Iris data set. So we’ll do a training run there, using the following steps:

  1. Write the network type and generate a random network from it
  2. Read our input data into a format that Grenade uses
  3. Write a function to run a training iteration.
  4. Run it!

1. Write the Network type and Generate Network

So we've already done this first step for the most part. We’ll adjust the names a little bit though. Note that I’ll include the imports list as an appendix to the post. Also, the code is on the grenade branch of my Haskell Tensor Flow repository in IrisGrenade.hs!

type IrisNetwork = Network 
  '[FullyConnected 4 10, Relu, FullyConnected 10 3]
  '[ 'D1 4, 'D1 10, 'D1 10, 'D1 3]

randomIris :: MonadRandom m => m IrisNetwork
randomIris = randomNetwork

runIris :: FilePath -> FilePath -> IO ()
runIris trainingFile testingFile = do
  initialNetwork <- randomIris
  ...

2. Take in our Input Data

We’ll make use of the readIrisFromFile function we used back when we first did Iris. Then we'll define a type called IrisRow, which uses Grenade's dependently typed S type. This S type is a container for a shape. We want our input data to use D1 4 for the 4 input features. Then our output data should use D1 3 for the three possible categories.

-- Dependent type on the dimensions of the row
type IrisRow = (S ('D1 4), S ('D1 3))

If we have malformed data, the types will not match up, so we’ll need to return a Maybe to ensure this succeeds. Note that we normalize the data by dividing by 8. This puts all the data between 0 and 1 and makes for better training results. Here's how we parse the data:

parseRecord :: IrisRecord -> Maybe IrisRow
parseRecord record = case (input, output) of
  (Just i, Just o) -> Just (i, o)
  _ -> Nothing
  where
    input = fromStorable $ VS.fromList $ float2Double <$>
      [ field1 record / 8.0, field2 record / 8.0, field3 record / 8.0, field4 record / 8.0]
    output = oneHot (fromIntegral $ label record)
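The same Maybe plumbing can be mirrored with plain lists. The oneHot, normalise, and encodeRow functions below are hypothetical, list-based stand-ins for the Grenade functions used in parseRecord. Like the article, normalise assumes every feature lies between 0 and 8.

```haskell
-- One-hot encode a label into a vector of the given size,
-- failing when the label is out of range.
oneHot :: Int -> Int -> Maybe [Double]
oneHot size label
  | label >= 0 && label < size = Just [if i == label then 1 else 0 | i <- [0 .. size - 1]]
  | otherwise = Nothing

-- Scale raw measurements into [0, 1], assuming every feature lies in [0, 8].
normalise :: [Double] -> [Double]
normalise = map (/ 8.0)

-- Hypothetical raw row: four features and a label; malformed rows give Nothing.
encodeRow :: ([Double], Int) -> Maybe ([Double], [Double])
encodeRow (features, label)
  | length features == 4 = fmap (\o -> (normalise features, o)) (oneHot 3 label)
  | otherwise = Nothing
```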

Then we incorporate these into our main function:

runIris :: FilePath -> FilePath -> IO ()
runIris trainingFile testingFile = do
  initialNetwork <- randomIris
  trainingRecords <- readIrisFromFile trainingFile
  testRecords <- readIrisFromFile testingFile

  let trainingData = mapMaybe parseRecord (V.toList trainingRecords)
  let testData = mapMaybe parseRecord (V.toList testRecords)

  -- Catch if any were parsed as Nothing
  if length trainingData /= length trainingRecords || length testData /= length testRecords
    then putStrLn "Hmmm there were some problems parsing the data"
    else …

3. Write a Function to Train the Input Data

This is a multi-step process. First we’ll establish our learning parameters. We'll also write a function that will allow us to call the train function on a particular row element:

learningParams :: LearningParameters
learningParams = LearningParameters 0.01 0.9 0.0005

-- Train the network!
trainRow :: LearningParameters -> IrisNetwork -> IrisRow -> IrisNetwork
trainRow lp network (input, output) = train lp network input output

Next we’ll write two more helper functions that will help us test our results. The first will take the network and a test row, and return the row's expected output paired with the network's actual output. The second function will take these outputs and reverse the oneHot process to get the labels out (0, 1, or 2).

-- Takes a test row, returns the expected output and the network's actual output.
testRow :: IrisNetwork -> IrisRow -> (S ('D1 3), S ('D1 3))
testRow net (rowInput, expectedOutput) = (expectedOutput, runNet net rowInput)

-- Goes from a pair of output vectors to a pair of labels.
getLabels :: (S ('D1 3), S ('D1 3)) -> (Int, Int)
getLabels (S1D expectedOutput, S1D actualOutput) =
  (maxIndex (extract expectedOutput), maxIndex (extract actualOutput))

Finally we’ll write a function that will take our training data, test data, the network, and an iteration number. It will return the newly trained network, and log some results about how we’re doing. We’ll first take only a sample of our training data and adjust our parameters so that learning gets slower. Then we'll train the network by folding over the sampled data.

run :: [IrisRow] -> [IrisRow] -> IrisNetwork -> Int -> IO IrisNetwork
run trainData testData network iterationNum = do
  sampledRecords <- V.toList <$> chooseRandomRecords (V.fromList trainData)
  -- Slowly drop the learning rate
  let revisedParams = learningParams 
        { learningRate = learningRate learningParams * 0.99 ^ iterationNum}
  let newNetwork = foldl' (trainRow revisedParams) network sampledRecords
  ....
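To get a feel for how quickly the learning rate drops, here is the same arithmetic pulled out into a hypothetical `decayedRate` helper (not part of Grenade). Starting from 0.01, the rate is 0.0099 after one iteration and roughly 0.00366 after 100, since 0.99^100 ≈ 0.366.

```haskell
-- Same arithmetic as the revisedParams update in `run`:
-- the base rate is scaled by 0.99 once per iteration.
decayedRate :: Double -> Int -> Double
decayedRate base iterationNum = base * 0.99 ^ iterationNum

-- Rates at iterations 1, 50, and 100, starting from 0.01.
sampleRates :: [Double]
sampleRates = map (decayedRate 0.01) [1, 50, 100]
```

Note that `^` binds more tightly than `*`, so only the 0.99 factor is raised to the iteration power.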

Then we’ll wrap up the function by looking at our test data, and seeing how much we got right!

run :: [IrisRow] -> [IrisRow] -> IrisNetwork -> Int -> IO IrisNetwork
run trainData testData network iterationNum = do
  sampledRecords <- V.toList <$> chooseRandomRecords (V.fromList trainData)
  -- Slowly drop the learning rate
  let revisedParams = learningParams 
        { learningRate = learningRate learningParams * 0.99 ^ iterationNum}
  let newNetwork = foldl' (trainRow revisedParams) network sampledRecords
  let labelVectors = fmap (testRow newNetwork) testData
  let labelValues = fmap getLabels labelVectors
  let total = length labelValues
  let correctEntries = length $ filter ((==) <$> fst <*> snd) labelValues
  putStrLn $ "Iteration: " ++ show iterationNum
  putStrLn $ show correctEntries ++ " correct out of: " ++ show total
  return newNetwork
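The `(==) <$> fst <*> snd` expression in the filter can look cryptic. It relies on the Applicative instance for functions: applied to a pair, it compares the pair's two components. Here's a small sketch (the names `sameLabel`, `sameLabel'`, and `countCorrect` are hypothetical) showing it next to the more common `uncurry (==)`:

```haskell
-- Each pair holds (label from the data, label from the network).
-- `(==) <$> fst <*> snd` uses the Applicative instance for functions:
-- applied to a pair p, it evaluates to fst p == snd p.
sameLabel :: (Int, Int) -> Bool
sameLabel = (==) <$> fst <*> snd

-- The same predicate written with uncurry.
sameLabel' :: (Int, Int) -> Bool
sameLabel' = uncurry (==)

-- Count the pairs whose two labels agree.
countCorrect :: [(Int, Int)] -> Int
countCorrect = length . filter sameLabel
```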

4. Run it!

Now we’ll call this from our runIris function, iterating 100 times, and we’re done!

runIris :: FilePath -> FilePath -> IO ()
runIris trainingFile testingFile = do
  ...
  if length trainingData /= length trainingRecords || length testData /= length testRecords
    then putStrLn "Hmmm there were some problems parsing the data"
    else foldM_ (run trainingData testData) initialNetwork [1..100]
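One detail worth noticing: `foldM_` feeds each iteration's network into the next iteration, but then discards the final network. That's fine here since `run` logs as it goes. If you wanted to keep the trained network (to save it or test it further), `foldM` from Control.Monad returns it instead. Here's a toy sketch, where a hypothetical `step` function stands in for `run` and an `Int` stands in for the network:

```haskell
import Control.Monad (foldM)

-- Hypothetical stand-in for `run`: the "network" is just an Int,
-- and each "training" step adds the iteration number to it.
step :: Int -> Int -> IO Int
step network iterationNum = do
  putStrLn ("Iteration: " ++ show iterationNum)
  return (network + iterationNum)

-- foldM threads each step's result into the next step and returns the last one.
-- foldM_ would run exactly the same steps but discard that final value.
trainFor :: Int -> IO Int
trainFor n = foldM step 0 [1 .. n]
```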

Comparing to Tensor Flow

Now that we’ve looked at a different library, we can consider how it stacks up against Tensor Flow. First, the advantages. Grenade's main advantage is that it provides dependent type facilities, which make it more difficult to write incorrect programs. The basic networks you build are guaranteed to have the correct dimensionality. Additionally, it does not use a “placeholders” system, so you can avoid that class of errors too. This means you're likely to run into fewer runtime bugs using Grenade.

Concision is another major strong point. The training code got a bit involved when translating our data into Grenade's format, but it’s no more complicated than Tensor Flow. When it comes down to the definition of the network itself, we do this in only a few lines with Grenade. It can be hard to understand what those lines mean if you are new to dependent types. But after seeing a few simple examples you should be able to follow the general pattern.

Of course, none of this means that Tensor Flow is without its advantages. As we saw a couple of weeks ago, it is not too difficult to add very thorough logging to your Tensor Flow program. The Tensor Board application will then give you excellent visualizations of this data. It is somewhat more difficult to get intermediate log results with Grenade. There is not much transparency (at least that I have found) into the inner values of the network. The network types are composable, though, so it is possible to get intermediate steps of your operation. But if you break your network into different types and stitch them together, you will lose some of the network's concision.

Tensor Flow also has a much richer ecosystem of machine learning tools. Grenade is still limited to a subset of the most common machine learning layers, like convolution and max pooling. Tensor Flow’s API also supports approaches like support vector machines and linear models, so it offers you more options.

One question I may explore in a future article would be to compare the performance of the two libraries. My suspicion is that Tensor Flow is faster due to how it gets all its math down to the C-level. But I’m not too familiar yet with HMatrix (which Grenade depends on for its math) and its efficiency. So I could definitely be wrong.

Conclusion

Grenade provides some truly awesome facilities for building a concise neural network. A Grenade program can demonstrate at compile time that the network is well formed. It also offers an incredibly concise way to define what layers your neural network has. It doesn’t have the Google-level support that Tensor Flow does, so it lacks many cool features like logging and visualizations. But it is quite a neat library for its scope. One thing I haven’t mentioned is its facilities for Generative Adversarial Networks. I’d definitely like to try that out soon!

Grenade is a simpler library to incorporate into Stack compared to Tensor Flow. If you want to compare the two, you should check out our Haskell Tensor Flow guide so you can install TF and get started!

If you’ve never written a line of Haskell before, never fear! Download our Getting Started Checklist for some free resources to start your Haskell education!

Appendix: Compiler Extensions and Imports

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GADTs #-}

import           Control.Monad (foldM_)
import           Control.Monad.Random (MonadRandom)
import           Control.Monad.IO.Class (liftIO)
import           Data.Foldable (foldl')
import           Data.Maybe (mapMaybe)
import qualified Data.Vector.Storable as VS
import qualified Data.Vector as V
import           GHC.Float (float2Double)
import           Grenade
import           Grenade.Core.LearningParameters (LearningParameters(..))
import           Grenade.Core.Shape (fromStorable)
import           Grenade.Utils.OneHot (oneHot)
import           Numeric.LinearAlgebra (maxIndex)
import           Numeric.LinearAlgebra.Static (extract)

import           Processing (IrisRecord(..), readIrisFromFile, chooseRandomRecords)
James Bowen

Checking it's all in Place: Placeholders and Dependent Types

Last week we dove into the world of dependent types. We linked tensors with their shapes at the type level. This gave our program some extra type safety and allowed us to avoid certain runtime errors.

This week, we’re going to solve another runtime conundrum: missing placeholders. We’ll add some more dependent type machinery to ensure we've plugged in all the necessary placeholders! But we’ll see this is not as straightforward as shapes.

To follow along with the code in this article, take a look at this branch on my Haskell Tensor Flow GitHub repository. All the code for this article is in DepShape.hs. As usual, I've listed the necessary compiler extensions and imports at the bottom of this article. If you want to run the code yourself, you'll have to get Haskell and Tensor Flow running first. Take a look at our Haskell Tensor Flow guide for that!

Now to start, let’s remind ourselves what placeholders are in Tensor Flow and how we use them.

Placeholder Review

Placeholders represent tensors that can have different values on different application runs. This is often the case when we’re training on different samples of data. Here’s our very simple example in Python. We’ll create a couple of placeholder tensors by providing their data types and no values. Then when we actually run the session, we’ll provide a value for each of those tensors.

node1 = tf.placeholder(tf.float32)
node2 = tf.placeholder(tf.float32)
adderNode = tf.add(node1, node2)
sess = tf.Session()
result1 = sess.run(adderNode, {node1: 3, node2: 4.5 })

The weakness here is that there’s nothing forcing us to provide values for those tensors! We could try running our program without them and we’ll get a runtime crash:

...
sess = tf.Session()
result1 = sess.run(adderNode)
print(result1)
…

Terminal Output:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype float
   [[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Unfortunately, the Haskell Tensor Flow library doesn’t actually do any better here. When we want to fill in placeholders, we provide a list of “feeds”. But our program will still compile even if we pass an empty list! We’ll encounter similar runtime errors:

(node1 :: Tensor Value Float) <- placeholder [1]
(node2 :: Tensor Value Float) <- placeholder [1]
let adderNode = node1 `add` node2
let runStep = \node1Feed node2Feed -> runWithFeeds [] adderNode
runStep (encodeTensorData [1] input1) (encodeTensorData [1] input2)
…

Terminal Output:

TensorFlowException TF_INVALID_ARGUMENT "You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [1]\n\t [[Node: Placeholder_1 = Placeholder[dtype=DT_FLOAT, shape=[1], _device=\"/job:localhost/replica:0/task:0/cpu:0\"]()]]"

In the Iris and MNIST examples, we bury the call to runWithFeeds within our neural network API. We only provide a Model object. This model object forces us to provide the expected input and output tensors. So anyone using our model wouldn't make a manual runWithFeeds call.

data Model = Model
  { train :: TensorData Float
          -> TensorData Int64
          -> Session ()
  , errorRate :: TensorData Float
              -> TensorData Int64
              -> SummaryTensor
              -> Session (Float, ByteString)
  }

This isn’t a bad solution! But it’s interesting to see how we can push the envelope with dependent types, so let’s try that!

Adding More “Safe” Types

The first step we’ll take is to augment Tensor Flow’s TensorData type. We’ll want it to carry shape information, like SafeTensor and SafeShape. But we’ll also attach a name to each piece of data. This will let us identify which placeholder tensor the data should be fed into. At the type level, we refer to this name as a Symbol.

data SafeTensorData a (n :: Symbol) (s :: [Nat]) where
  SafeTensorData :: (TensorType a) => TensorData a -> SafeTensorData a n s

Next, we’ll need to make some changes to our SafeTensor type. First, each SafeTensor will get a new type parameter. This parameter refers to a mapping of names (symbols) to shapes (which are still lists of naturals). We'll call this a placeholder list. So each tensor will have type-level information for the placeholders it depends on. Each different placeholder has a name and a shape.

data SafeTensor v a (s :: [Nat]) (p :: [(Symbol, [Nat])]) where
  SafeTensor :: (TensorType a) => Tensor v a -> SafeTensor v a s p

Now, recall that when we substituted for placeholders, we used a list of feeds. But this list had no information about the names or dimensions of its feeds. Let's create a new type containing the different elements we need for our feeds. It should also contain the correct type information about the placeholder list. The first step is to define the type so that it carries the list of placeholders it contains, like the SafeTensor.

data FeedList (pl :: [(Symbol, [Nat])]) where

This structure will look like a linked list, like our SafeShape. Thus we’ll start by defining an “empty” constructor:

data FeedList (pl :: [(Symbol, [Nat])]) where
  EmptyFeedList :: FeedList '[]

Now we’ll add a “Cons”-like constructor by creating yet another operator, :--:. Each “piece” of our linked list will contain two items: first, the tensor we are substituting for, and second, the data we’ll use for the substitution. We can use type parameters to force their shapes and data types to match. Then we need the resulting placeholder type: we prepend the type-level tuple containing the symbol and shape to the previous list. This completes our definition.

data FeedList (pl :: [(Symbol, [Nat])]) where
  EmptyFeedList :: FeedList '[]
  (:--:) :: (KnownSymbol n)
    => (SafeTensor Value a s p, SafeTensorData a n s) 
    -> FeedList pl
    -> FeedList ( '(n, s) ': pl)

infixr 5 :--:

Note that we force the tensor to be a Value tensor. We can only substitute data for rendered tensors, hence this restriction. Let's add a quick safeRender so we can render our SafeTensor items.

safeRender :: (MonadBuild m) => SafeTensor Build a s pl -> m (SafeTensor Value a s pl)
safeRender (SafeTensor t1) = do
  t2 <- render t1
  return $ SafeTensor t2
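It can help to see the FeedList trick without any TensorFlow machinery. The following is a toy sketch, not the code from the repository: `Slot`, `Feeds`, `feedPairs`, and `runAdder` are hypothetical names, a "placeholder" is only a name at the type level, and each feed is a single Double. The GADT index plays the same role as FeedList's placeholder list.

```haskell
{-# LANGUAGE DataKinds      #-}
{-# LANGUAGE GADTs          #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeOperators  #-}

import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)

-- A toy "placeholder" that carries only its name at the type level.
-- (Hypothetical stand-in for SafeTensor; no TensorFlow involved.)
data Slot (n :: Symbol) = Slot

-- Same idea as FeedList: each cons cell records its name in the type index.
data Feeds (ns :: [Symbol]) where
  NoFeeds :: Feeds '[]
  (:&)    :: KnownSymbol n => (Slot n, Double) -> Feeds ns -> Feeds (n ': ns)

infixr 5 :&

-- Flatten the typed list into ordinary (name, value) pairs,
-- much as buildFeedList flattens a FeedList into [Feed].
feedPairs :: Feeds ns -> [(String, Double)]
feedPairs NoFeeds             = []
feedPairs ((slot, v) :& rest) = (symbolVal slot, v) : feedPairs rest

-- A toy "graph" that needs exactly the placeholders "b" and "a".
-- A Feeds value with a missing, extra, or reordered entry will not type-check.
runAdder :: Feeds '["b", "a"] -> Double
runAdder feeds = sum (map snd (feedPairs feeds))
```

For example, `runAdder ((Slot :: Slot "b", 4.5) :& (Slot :: Slot "a", 3) :& NoFeeds)` compiles and gives 7.5, while dropping the "a" entry is rejected by the type checker, just like the missing-feed error shown later.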

Making a Placeholder

Now we can write our safePlaceholder function. We’ll add a KnownSymbol type constraint. Then we’ll take a SafeShape to give ourselves the type information for the shape. The result is a new tensor whose placeholder list maps the symbol to the shape.

safePlaceholder :: (MonadBuild m, TensorType a, KnownSymbol sym) => 
  SafeShape s -> m (SafeTensor Value a s '[ '(sym, s)])
safePlaceholder shp = do
  pl <- placeholder (toShape shp)
  return $ SafeTensor pl

This looks a little crazy, and it kind of is! But we’ve now created a tensor that stores its own placeholder information at the type level!

Updating Old Code

Now that we’ve done this, we’re also going to have to update some of our older code. The first part of this is pretty straightforward. We’ll need to change safeConstant so that it has the type information. It will have an empty list for the placeholders.

safeConstant :: (TensorType a, ShapeProduct s ~ n) => 
  Vector n a -> SafeShape s -> SafeTensor Build a s '[]
safeConstant elems shp = SafeTensor (constant (toShape shp) (toList elems))

Our mathematical operations will be a bit more tricky though. Consider adding two arbitrary tensors. They may share placeholder dependencies but may not. What should be the placeholder type for the resulting tensor? Obviously the union of the two placeholder maps of the input tensors! Luckily for us, we can use Union from the type-list library to represent this concept.

safeAdd :: (TensorType a, a /= Bool, TensorKind v)
  => SafeTensor v a s p1
  -> SafeTensor v a s p2
  -> SafeTensor Build a s (Union p1 p2)
safeAdd (SafeTensor t1) (SafeTensor t2) = SafeTensor (t1 `add` t2)

We’ll make the same update with matrix multiplication:

safeMatMul :: (TensorType a, a /= Bool, a /= Int8, a /= Int16,
               a /= Int64, a /= Word8, a /= ByteString, TensorKind v)
   => SafeTensor v a '[i,n] p1 -> SafeTensor v a '[n,o] p2 -> SafeTensor Build a '[i,o] (Union p1 p2)
safeMatMul (SafeTensor t1) (SafeTensor t2) = SafeTensor (t1 `matMul` t2)

Running with Placeholders

Now we have all the information we need to write our safeRun function. This will take a SafeTensor, and it will also take a FeedList with the same placeholder type. Remember, a FeedList contains a series of SafeTensorData items. They must match up symbol-for-symbol and shape-for-shape with the placeholders within the SafeTensor. Let’s look at the type signature:

safeRun :: (TensorType a, Fetchable (Tensor v a) r) =>
  FeedList pl -> SafeTensor v a s pl -> Session r

The Fetchable constraint enforces that we can actually get the “result” r out of our tensor. For instance, we can "fetch" a vector of floats out of a tensor that uses Float as its underlying value.

Next we’ll define a tail-recursive helper function that builds the vanilla “list of feeds” out of our FeedList. Through pattern matching, we can pick out the tensor to substitute for and the data we’re using. We combine these into a feed and cons it onto the accumulated list:

safeRun = ...
  where
    buildFeedList :: FeedList ss -> [Feed] -> [Feed]
    buildFeedList EmptyFeedList accum = accum
    buildFeedList ((SafeTensor tensor_, SafeTensorData data_) :--: rest) accum = 
      buildFeedList rest ((feed tensor_ data_) : accum)

Now all we have to do to finish up is call the normal runWithFeeds function with the list we’ve created!

safeRun :: (TensorType a, Fetchable (Tensor v a) r) =>
  FeedList pl -> SafeTensor v a s pl -> Session r
safeRun feeds (SafeTensor finalTensor) = runWithFeeds (buildFeedList feeds []) finalTensor
  where
  ...

And here’s what it looks like to use this in practice with our simple example. Notice the type signatures do get a little cumbersome. The signatures we place on the initial placeholder tensors are necessary. Otherwise the compiler wouldn't know what label we're giving them! The signature containing the union of the types is unnecessary. We can remove it if we want and let type inference do its work.

main3 :: IO (VN.Vector Float)
main3 = runSession $ do
  let (shape1 :: SafeShape '[2,2]) = fromJust $ fromShape (Shape [2,2])
  (a :: SafeTensor Value Float '[2,2] '[ '("a", '[2,2])]) <- safePlaceholder shape1
  (b :: SafeTensor Value Float '[2,2] '[ '("b", '[2,2])] ) <- safePlaceholder shape1
  let result = a `safeAdd` b
  (result_ :: SafeTensor Value Float '[2,2] '[ '("b", '[2,2]), '("a", '[2,2])]) <- safeRender result
  let (feedA :: Vector 4 Float) = fromJust $ fromList [1,2,3,4]
  let (feedB :: Vector 4 Float) = fromJust $ fromList [5,6,7,8]
  let fullFeedList = (b, safeEncodeTensorData shape1 feedB) :--:
                     (a, safeEncodeTensorData shape1 feedA) :--:
                     EmptyFeedList
  safeRun fullFeedList result_

{- It runs!
[6.0,8.0,10.0,12.0]
-}

Now suppose we make some mistakes with our types. Here we’ll take out the “A” feed from our feed list:

-- Let’s take out Feed A!
main = …
  let fullFeedList = (b, safeEncodeTensorData shape1 feedB) :--:
                     EmptyFeedList
  safeRun fullFeedList result_

{- Compiler Error!
• Couldn't match type ‘'['("a", '[2, 2])]’ with ‘'[]’
      Expected type: SafeTensor Value Float '[2, 2] '['("b", '[2, 2])]
        Actual type: SafeTensor
                       Value Float '[2, 2] '['("b", '[2, 2]), '("a", '[2, 2])]
-}

Here’s what happens when we try to substitute a vector with the wrong size. It will identify that we have the wrong number of elements!

main = …
  -- Wrong Size!
  let (feedA :: Vector 8 Float) = fromJust $ fromList [1,2,3,4,5,6,7,8]
  let (feedB :: Vector 4 Float) = fromJust $ fromList [5,6,7,8]
  let fullFeedList = (b, safeEncodeTensorData shape1 feedB) :--:
                     (a, safeEncodeTensorData shape1 feedA) :--:
                     EmptyFeedList
  safeRun fullFeedList result_

{- Compiler Error!
Couldn't match type ‘4’ with ‘8’
        arising from a use of ‘safeEncodeTensorData’
-}

Conclusion: Pros and Cons

So let’s take a step back and look at what we’ve constructed here. We’ve managed to provide ourselves with some pretty cool compile time guarantees. We’ve also added de-facto documentation to our code. Anyone familiar with the codebase can tell at a glance what placeholders we need for each tensor. It’s a lot harder now to write incorrect code. There are still error conditions of course. But if we’re smart we can write our code to deal with these all upfront. That way we can fail gracefully instead of throwing a random run-time crash somewhere.

But there are drawbacks. Imagine being a Haskell novice and walking into this codebase. You’ll have no real clue what’s going on (I wouldn’t have 2 months ago). The types are very cumbersome after a while, so continuing to write them down gets very tedious. Though as I mentioned, type inference can deal with a lot of that. But if you don’t track them, the type union can be finicky about the ordering of your placeholders. We could fix this with another type family though.

All these factors could present a real drag on development. But then again, tracking down run-time errors can also do this. Tensor Flow’s error messages can still be a little cryptic. This can make it hard to find root causes.

Since I’m still a novice with dependent types, this code was a little messy. Next week we’ll take a look at a more polished library that uses dependent types for neural networks. We’ll see how the Grenade library allows us to specify a learning system in just a few lines of code!

If you’re new to Haskell, I hope none of this dependent type madness scared you! The language is much easier than these last couple posts make it seem! Try it out, and download our Getting Started Checklist. It'll give you some instructions and tools to help you learn!

If you’re an experienced Haskeller and want to try out Tensor Flow, download our Tensor Flow Guide! It will walk you through incorporating the library into a Stack project!

Appendix: Compiler Extensions and Imports

{-# LANGUAGE GADTs                #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE UndecidableInstances #-}

import           Data.ByteString (ByteString)
import           Data.Int (Int64, Int8, Int16)
import           Data.Maybe (fromJust)
import           Data.Proxy (Proxy(..))
import           Data.Type.List (Union)
import qualified Data.Vector as VN
import           Data.Vector.Sized (Vector, toList, fromList)
import           Data.Word (Word8)
import           GHC.TypeLits (Nat, KnownNat, natVal)
import           GHC.TypeLits

import           TensorFlow.Core
import           TensorFlow.Core (Shape(..), TensorType, Tensor, Build)
import           TensorFlow.Ops (constant, add, matMul, placeholder)
import           TensorFlow.Session (runSession, run)
import           TensorFlow.Tensor (TensorKind)
James Bowen

Deep Learning and Deep Types: Tensor Flow and Dependent Types

In the introduction to this series, one primary point I made was that Haskell is a safe language. There are a lot of errors we will catch at compile time, rather than runtime. Runtime errors can often be catastrophic to a system, so being able to reduce these is paramount. This is especially true when programming an autonomous car or drone. These objects will be out in the real world where they can hurt people if they malfunction.

So let’s take a look back at some of the code we’ve written over the last 3 or 4 weeks. Is it actually any safer? We’ll find the answer is, well, not so much. It's hard to verify certain properties about code. But the facilities for making this code safer do exist in Haskell! In the next couple articles we'll do some serious hacking with dependent types. We'll be able to prove some of these difficult properties of AI programs at compile time!

The next three articles will focus on dependent type programming. This is a difficult topic, so don’t worry if you can’t follow all the code examples completely. The main idea of making our machine learning code safer is what’s important! So without further ado, let’s rewind to the beginning to see where runtime issues can appear.

If you want to play with this code yourself, check out the dependent shapes branch on my GitHub repository! All the code for this article is in DepShape.hs. If you want to get the code to run, though, you'll probably also need to get Haskell Tensor Flow working. Download our Haskell Tensor Flow Guide for instructions on that!

Issues with Python

Python, as an interpreted language, is definitely subject to runtime bugs. As I was first learning Tensor Flow, I came across quite a few common ones. The two that stood out to me most were placeholder failures and dimension mismatches. For instance, let’s think back to one of the first examples. Our code will have a couple of placeholders, and we submit values for those when we run the session:

node1 = tf.placeholder(tf.float32)
node2 = tf.placeholder(tf.float32)
adderNode = tf.add(node1, node2)
sess = tf.Session()
result1 = sess.run(adderNode, {node1: 3, node2: 4.5 })

But there’s nothing stopping us from trying to run the session without submitting values. This will result in a runtime crash:

...
sess = tf.Session()
result1 = sess.run(adderNode)
print(result1)
…

Terminal Output:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype float
   [[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Another issue that came up from time to time was dimension mismatches. Certain operations need certain relationships between the dimensions of the tensors. For instance, you can’t add two vectors with different lengths:

node1 = tf.constant([3.0, 4.0, 5.0], dtype=tf.float32)
node2 = tf.constant([4.0, 16.0], dtype=tf.float32)
additionNode = tf.add(node1, node2)

sess = tf.Session()
result = sess.run(additionNode)
print(result)

…

Terminal Output:

ValueError: Dimensions must be equal, but are 3 and 2 for 'Add' (op: 'Add') with input shapes: [3], [2].

Again, we get a runtime crash. These seem like the kinds of problems we can solve at compile time.

Does Haskell Solve these Issues?

But anyone who takes a close look at the Haskell code I’ve written so far can see that it doesn’t solve these issues! Here’s a review of our basic placeholder example:

runPlaceholder :: Vector Float -> Vector Float -> IO (Vector Float)
runPlaceholder input1 input2 = runSession $ do
  (node1 :: Tensor Value Float) <- placeholder [1]
  (node2 :: Tensor Value Float) <- placeholder [1]
  let adderNode = node1 `add` node2
  let runStep = \node1Feed node2Feed -> runWithFeeds 
        [ feed node1 node1Feed
        , feed node2 node2Feed
        ] 
        adderNode
  runStep (encodeTensorData [1] input1) (encodeTensorData [1] input2)

Notice how the runWithFeeds function takes a list of Feed objects. The code would still compile fine if we supplied the empty list. Then it would face a fate no better than our Python code:

…
let runStep = \node1Feed node2Feed -> runWithFeeds [] adderNode
…

Terminal Output:

TensorFlowException TF_INVALID_ARGUMENT "You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [1]\n\t [[Node: Placeholder_1 = Placeholder[dtype=DT_FLOAT, shape=[1], _device=\"/job:localhost/replica:0/task:0/cpu:0\"]()]]"

We can make the second kind of mistake, a dimension mismatch, in Haskell too. The following code compiles but crashes at runtime:

runSimple :: IO (Vector Float)
runSimple = runSession $ do
  let node1 = constant [3] [3 :: Float, 4, 5]
  let node2 = constant [2] [4 :: Float, 5]
  let additionNode = node1 `add` node2
  run additionNode
…

Terminal Output:
TensorFlowException TF_INVALID_ARGUMENT "Incompatible shapes: [3] vs. [2]\n\t [[Node: Add_2 = Add[T=DT_FLOAT, _device=\"/job:localhost/replica:0/task:0/cpu:0\"](Const_0, Const_1)]]"

At an even more basic level, we don’t even have to tell the truth about the shape of our vectors! We can give a bogus shape value and it will still compile!

let node1 = constant [3, 2, 3] [3 :: Float, 4, 5]
…

Terminal Output:
invalid tensor length: expected 18 got 3
CallStack (from HasCallStack):
  error, called at src/TensorFlow/Ops.hs:299:23 in tensorflow-ops-0.1.0.0-EWsy8DQdciaL8o6yb2fUKR:TensorFlow.Ops

Can we do better?

Now, we did do some things right. Let's think back to our Model type when we made neural networks.

data Model = Model
  { train :: TensorData Float
          -> TensorData Int64
          -> Session ()
  , errorRate :: TensorData Float
              -> TensorData Int64
              -> SummaryTensor
              -> Session (Float, ByteString)
  }

We exposed our training step as a function. This function forced the user to supply both of the tensors for the placeholders. This is good, but doesn't protect us from dimension issues.

When trying to solve these, we could write wrappers around every operation. Functions like add and matMul could return Maybe values. But this would be clunky. We could take this same step in Python. Granted, monads would allow the Haskell version to compose better. But it would be nicer if we could check our errors all at once, up front.
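For comparison, here's roughly what that runtime-checked approach could look like. This is a hypothetical sketch with a toy `Dense` type (dimensions plus flat values) rather than real TensorFlow tensors. The Maybe monad composes the checks nicely, but a mismatch still only shows up when the code runs:

```haskell
-- Hypothetical runtime-checked tensor: dimensions plus flat values.
data Dense = Dense { dims :: [Int], vals :: [Double] } deriving (Eq, Show)

-- Build a Dense only if the element count matches the product of the dims.
mkDense :: [Int] -> [Double] -> Maybe Dense
mkDense ds vs
  | product ds == length vs = Just (Dense ds vs)
  | otherwise               = Nothing

-- Element-wise addition that refuses mismatched shapes.
checkedAdd :: Dense -> Dense -> Maybe Dense
checkedAdd (Dense d1 v1) (Dense d2 v2)
  | d1 == d2  = Just (Dense d1 (zipWith (+) v1 v2))
  | otherwise = Nothing

-- The mismatched example from above: every check happens at runtime,
-- and the whole computation collapses to Nothing.
example :: Maybe Dense
example = do
  x <- mkDense [3] [3, 4, 5]
  y <- mkDense [2] [4, 5]
  checkedAdd x y
```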

If we’re willing to dig quite a bit deeper, we can solve these problems! In the rest of this post, we’ll explore using dependent types to ensure dimensions are always correct. Getting placeholders right turns out to be a little more complicated though! So we’ll save that for next week’s post.

Checking Dimensions

Currently, the Tensor Types we’ve been dealing with have no type safety on the dimensions. Tensor Flow doesn't provide this information when interacting with the C library. So it’s impossible to enforce it at a low level. But this doesn’t stop us from writing wrappers that allow us to solve this.

To write these wrappers, we’re going to need to dive into dependent types. I’ll give a high level overview of what’s going on. But for some details on the basics, you should check out this tutorial. I’ll also give a shout-out to Renzo Carbonara, author of the Exinst library and other great Haskell things. He helped me a lot in crossing a couple of big knowledge gaps for implementing dependent types.

Intro to Dependent Types: Sized Vectors

The simplest example for introducing dependent types is the idea of sized vectors. If you read the tutorial above, you'll see how they're implemented from scratch. A normal vector has a single type parameter, referring to what type of item the vector contains. A sized vector has an extra type parameter, and this type refers to the size of the vector. For instance, the following are valid sized vector types:

import Data.Vector.Sized (Vector, fromList)

vectorWith2 :: Vector 2 Int64
...
vectorWith6 :: Vector 6 Float
...

In the first type signature, 2 does not refer to the term 2. It refers to the type 2. That is, we’ve promoted the number to the type level, where 2 is a type of kind Nat rather than a value. The mechanics of how this works are confusing, but here’s the result. We can try to convert normal vectors to sized vectors. But the operation will fail if we don’t match up the size.

import Data.Vector.Sized (Vector, fromList)
import GHC.TypeLits (KnownNat)

-- fromList :: (KnownNat n) => [a] -> Maybe (Vector n a)

-- This results in a “Just” value!
success :: Maybe (Vector 2 Int64)
success = fromList [5,6]

-- The sizes don’t match, so we’ll get “Nothing”!
failure :: Maybe (Vector 2 Int64)
failure = fromList [3,1,5]

The KnownNat constraint requires that n be a type-level natural number whose value we can recover at runtime (with natVal). So now we can assign a type signature that encapsulates the size of the list.
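To make the KnownNat idea concrete, here's a minimal sketch of how a length-checking fromList could work. `Sized` and `fromListSized` are hypothetical stand-ins, not the vector-sized implementation; the point is that `natVal` turns the type-level size back into an Integer we can compare against at runtime.

```haskell
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Proxy (Proxy(..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical minimal sized list; the size lives only in the type.
newtype Sized (n :: Nat) a = Sized [a] deriving (Show)

-- KnownNat n lets us read the type-level size back as an Integer
-- and compare it with the actual length of the input list.
fromListSized :: forall n a. KnownNat n => [a] -> Maybe (Sized n a)
fromListSized xs
  | toInteger (length xs) == natVal (Proxy :: Proxy n) = Just (Sized xs)
  | otherwise                                          = Nothing

toListSized :: Sized n a -> [a]
toListSized (Sized xs) = xs
```

Mirroring the examples above, `fromListSized [5, 6] :: Maybe (Sized 2 Int)` is a Just, while `fromListSized [3, 1, 5] :: Maybe (Sized 2 Int)` is Nothing.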

A “Safe” Shape type

Now that we have a very basic understanding of dependent types, let's come up with a gameplan for Tensor Flow. The first step will be to make a new type that puts the shape into the type signature. We'll make a SafeShape type that mimics the sized vector type. Instead of storing a single number as the type, it will store the full list of dimensions. We want to create an API something like this:

-- fromShape :: Shape -> Maybe (SafeShape s)

-- Results in a “Just” value
goodShape :: Maybe (SafeShape '[2, 2])
goodShape = fromShape (Shape [2,2])

-- Results in Nothing
badShape :: Maybe (SafeShape '[2,2])
badShape = fromShape (Shape [3,3,2])

So to do this, we first define the SafeShape type. This follows the example of sized vectors. See the appendix below for compiler extensions and imports used throughout this article. In particular, you want GADTs and DataKinds.

data SafeShape (s :: [Nat]) where
  NilShape :: SafeShape '[]
  (:--) :: KnownNat m => Proxy m -> SafeShape s -> SafeShape (m ': s)

infixr 5 :--

Now we can define the toShape function. This will take our SafeShape and turn it into a normal Shape using proxies.

toShape :: SafeShape s -> Shape
toShape NilShape = Shape []
toShape ((pm :: Proxy m) :-- s) = Shape (fromInteger (natVal pm) : s')
  where
    (Shape s') = toShape s

Now for the reverse direction, we first have to make a class MkSafeShape. This class encapsulates all the types that we can turn into the SafeShape type. We’ll define instances of this class for all lists of naturals.

class MkSafeShape (s :: [Nat]) where
  mkSafeShape :: SafeShape s
instance MkSafeShape '[] where
  mkSafeShape = NilShape
instance (MkSafeShape s, KnownNat m) => MkSafeShape (m ': s) where
  mkSafeShape = Proxy :-- mkSafeShape

Now we can define our fromShape function using the MkSafeShape class. To check if it works, we’ll compare the resulting shape to the input shape and make sure they’re equal. Note this requires us to define a simple instance of Eq Shape.

instance Eq Shape where
  (==) (Shape s) (Shape r) = s == r

fromShape :: forall s. MkSafeShape s => Shape -> Maybe (SafeShape s)
fromShape shape = if toShape myShape == shape
  then Just myShape
  else Nothing
  where
    myShape = mkSafeShape :: SafeShape s
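If you want to experiment with this machinery without installing the TensorFlow bindings, here is a minimal standalone sketch of the definitions so far. It assumes a local newtype Shape that mirrors TensorFlow's Shape [Int64] constructor and derives Eq (standing in for the hand-written instance above); it is not the library type. Loading it into GHCi lets you check that fromShape accepts matching dimensions and rejects everything else.

```haskell
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}

import Data.Int (Int64)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Local stand-in for TensorFlow's Shape (an assumption for this sketch):
-- same constructor, but with a derived Eq instance.
newtype Shape = Shape [Int64]
  deriving (Eq, Show)

-- A shape whose dimensions are recorded in the type.
data SafeShape (s :: [Nat]) where
  NilShape :: SafeShape '[]
  (:--) :: KnownNat m => Proxy m -> SafeShape s -> SafeShape (m ': s)

infixr 5 :--

-- Type level to value level: read each dimension back with natVal.
toShape :: SafeShape s -> Shape
toShape NilShape = Shape []
toShape (pm :-- rest) = Shape (fromInteger (natVal pm) : dims)
  where
    Shape dims = toShape rest

-- Build the unique SafeShape for a known list of dimensions.
class MkSafeShape (s :: [Nat]) where
  mkSafeShape :: SafeShape s

instance MkSafeShape '[] where
  mkSafeShape = NilShape

instance (KnownNat m, MkSafeShape s) => MkSafeShape (m ': s) where
  mkSafeShape = Proxy :-- mkSafeShape

-- Value level to type level: succeed only if the runtime shape
-- matches the dimensions the caller asked for in the type.
fromShape :: forall s. MkSafeShape s => Shape -> Maybe (SafeShape s)
fromShape shape
  | toShape candidate == shape = Just candidate
  | otherwise = Nothing
  where
    candidate = mkSafeShape :: SafeShape s
```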

Now that we’ve done this for Shape, we can create a similar type for Tensor that will store the shape as a type parameter.

data SafeTensor v a (s :: [Nat]) where
  SafeTensor :: (TensorType a) => Tensor v a -> SafeTensor v a s

Using our Safe Types

So what has all this gotten us? Our next goal is to create a safeConstant function. This will let us create a SafeTensor wrapping a constant tensor and storing the shape. Remember, constant takes a shape and a list of values without checking that the two agree. We want something like this:

safeConstant :: (TensorType a) => Vector n a -> SafeShape s -> SafeTensor Build a s
safeConstant elems shp = SafeTensor $ constant (toShape shp) (toList elems)

This will attach the given shape to the tensor. But there’s one piece missing. We also want to create a connection between the number of input elements and the shape. So something with shape [3,3,2] should force you to input a vector of length 18. And right now, there is no constraint between n and s.

We’ll add this with a type family called ShapeProduct. The instances will state that the correct natural type for a given list of naturals is the product of them. We define the second instance with recursion, so we'll need UndecidableInstances.

type family ShapeProduct (s :: [Nat]) :: Nat
type instance ShapeProduct '[] = 1
type instance ShapeProduct (m ': s) = m * ShapeProduct s
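One way to convince yourself that ShapeProduct computes what we expect is to read the type-level result back as an ordinary number with natVal. The sketch below assumes GHC 8.6 or newer, where NoStarIsType is needed so that * means multiplication rather than the kind Type; elementCount is a hypothetical helper, not part of the article's code.

```haskell
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE NoStarIsType         #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal, type (*))

-- Same family as above: the number of elements a shape holds.
type family ShapeProduct (s :: [Nat]) :: Nat
type instance ShapeProduct '[] = 1
type instance ShapeProduct (m ': s) = m * ShapeProduct s

-- Hypothetical helper: reflect the type-level product to an Integer.
elementCount :: forall s. KnownNat (ShapeProduct s) => Proxy s -> Integer
elementCount _ = natVal (Proxy :: Proxy (ShapeProduct s))
```

For example, elementCount (Proxy :: Proxy '[3, 3, 2]) evaluates to 18, the vector length that safeConstant will demand for that shape.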

Now we’re almost done with this part! We can fix our safeConstant function by adding a constraint on the ShapeProduct between s and n.

safeConstant :: (TensorType a, ShapeProduct s ~ n) => Vector n a -> SafeShape s -> SafeTensor Build a s
safeConstant elems shp = SafeTensor $ constant (toShape shp) (toList elems)

Now we can write out a simple use of our safeConstant function as follows:

main :: IO (VN.Vector Int64)
main = runSession $ do
  let (shape1 :: SafeShape '[2,2]) = fromJust $ fromShape (Shape [2,2])
  let (elems1 :: Vector 4 Int64) = fromJust $ fromList [1,2,3,4]
  let (constant1 :: SafeTensor Build Int64 '[2,2]) = safeConstant elems1 shape1
  let (SafeTensor t) = constant1
  run t

We’re using fromJust as a shortcut here. But in a real program you would read your initial tensors in and check them as Maybe values. There's still the possibility of runtime failures. But this system has a couple of advantages. First, it won't crash. We'll have the opportunity to handle failures gracefully. Second, we do all the error checking up front. Once we've assigned types to everything, all the failure cases should be covered.

Going back to the last example, let's change something. For instance, we could make our vector have length 3 instead of 4. We’ll now get a compile error!

main :: IO (VN.Vector Int64)
main = runSession $ do
  let (shape1 :: SafeShape '[2,2]) = fromJust $ fromShape (Shape [2,2])
  let (elems1 :: Vector 3 Int64) = fromJust $ fromList [1,2,3]
  let (constant1 :: SafeTensor Build Int64 '[2,2]) = safeConstant elems1 shape1
  let (SafeTensor t) = constant1
  run t

…

    • Couldn't match type ‘4’ with ‘3’
        arising from a use of ‘safeConstant’
    • In the expression: safeConstant elems1 shape1
      In a pattern binding:
        (constant1 :: SafeTensor Build Int64 '[2, 2])
          = safeConstant elems1 shape1

Adding Type Safe Operations

Now that we’ve attached shape information to our tensors, we can define safer math operations. It's easy to write a safe addition function that ensures that the tensors have the same shape:

safeAdd :: (TensorType a, a /= Bool) => SafeTensor Build a s -> SafeTensor Build a s -> SafeTensor Build a s
safeAdd (SafeTensor t1) (SafeTensor t2) = SafeTensor (t1 `add` t2)

Here’s a similar matrix multiplication function. It ensures we have 2-dimensional shapes and that the dimensions work out. Notice the two tensors share the n dimension. It must be the column dimension of the first tensor and the row dimension of the second tensor:

safeMatMul :: (TensorType a, a /= Bool, a /= Int8, a /= Int16, a /= Int64, a /= Word8, a /= ByteString)
   => SafeTensor Build a '[i,n] -> SafeTensor Build a '[n,o] -> SafeTensor Build a '[i,o]
safeMatMul (SafeTensor t1) (SafeTensor t2) = SafeTensor (t1 `matMul` t2)
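The shape rules behind safeAdd and safeMatMul don't depend on TensorFlow at all. As an illustration, here is a hypothetical phantom-typed Matrix over plain lists, where the dimensions exist only in the type. Note that the Matrix constructor itself does not check its contents; in the article that job belongs to fromShape and safeConstant.

```haskell
{-# LANGUAGE DataKinds      #-}
{-# LANGUAGE KindSignatures #-}

import Data.List (transpose)
import GHC.TypeLits (Nat)

-- Hypothetical phantom-typed matrix: rows and cols exist only in the type.
newtype Matrix (rows :: Nat) (cols :: Nat) = Matrix [[Double]]
  deriving (Show)

-- Like safeMatMul: the shared dimension n must agree.
mulM :: Matrix i n -> Matrix n o -> Matrix i o
mulM (Matrix a) (Matrix b) =
  Matrix [ [ sum (zipWith (*) row col) | col <- transpose b ] | row <- a ]

-- Like safeAdd: both operands must have exactly the same dimensions.
addM :: Matrix r c -> Matrix r c -> Matrix r c
addM (Matrix a) (Matrix b) = Matrix (zipWith (zipWith (+)) a b)
```

With a :: Matrix 2 2 and b :: Matrix 2 1, mulM a b type-checks, while mulM b a is rejected at compile time because the inner dimensions 1 and 2 don't match.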

Here are these functions in action:

main2 :: IO (VN.Vector Float)
main2 = runSession $ do
  let (shape1 :: SafeShape '[4,3]) = fromJust $ fromShape (Shape [4,3])
  let (shape2 :: SafeShape '[3,2]) = fromJust $ fromShape (Shape [3,2])
  let (shape3 :: SafeShape '[4,2]) = fromJust $ fromShape (Shape [4,2])
  let (elems1 :: Vector 12 Float) = fromJust $ fromList [1,2,3,4,1,2,3,4,1,2,3,4]
  let (elems2 :: Vector 6 Float) = fromJust $ fromList [5,6,7,8,9,10]
  let (elems3 :: Vector 8 Float) = fromJust $ fromList [11,12,13,14,15,16,17,18]
  let (constant1 :: SafeTensor Build Float '[4,3]) = safeConstant elems1 shape1
  let (constant2 :: SafeTensor Build Float '[3,2]) = safeConstant elems2 shape2
  let (constant3 :: SafeTensor Build Float '[4,2]) = safeConstant elems3 shape3
  let (multTensor :: SafeTensor Build Float '[4,2]) = constant1 `safeMatMul` constant2
  let (addTensor :: SafeTensor Build Float '[4,2]) = multTensor `safeAdd` constant3
  let (SafeTensor finalTensor) = addTensor
  run finalTensor

And of course we’ll get compile errors if we use incorrect dimensions anywhere. Let’s say we change multTensor to use [4,3] as its type:

    • Couldn't match type ‘2’ with ‘3’
      Expected type: SafeTensor Build Float '[4, 3]
        Actual type: SafeTensor Build Float '[4, 2]
    • In the expression: constant1 `safeMatMul` constant2
…
    • Couldn't match type ‘3’ with ‘2’
      Expected type: SafeTensor Build Float '[4, 2]
        Actual type: SafeTensor Build Float '[4, 3]
    • In the expression: multTensor `safeAdd` constant3
…
    • Couldn't match type ‘2’ with ‘3’
      Expected type: SafeTensor Build Float '[4, 3]
        Actual type: SafeTensor Build Float '[4, 2]
    • In the second argument of ‘safeAdd’, namely ‘constant3’

Conclusion

In this exercise we got deep into the weeds of one of the most difficult topics to learn about in Haskell. Dependent types will make your head spin at first. But we saw a concrete example of how they can allow us to detect problematic code at compile time. They are a form of documentation that also enables us to verify that our code is correct in certain ways.

Types do not replace tests (especially behavioral tests). But in this instance there are at least a few different test cases we don’t need to worry about too much. Next week, we’ll see how we can apply these principles to verifying placeholders.

If you want to learn more about the nuts and bolts of using Haskell Tensor Flow, you should check out our Tensor Flow Guide. It will guide you through the basics of adding Tensor Flow to a simple Stack project.

Maybe you’ve never used Haskell before, but I’ve convinced you that dependent types are the future. If you want to try it out, download our Getting Started Checklist. You can also learn how to create and organize Haskell projects using Stack! Check out our Stack mini-course!

Appendix: Extensions and Imports

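{-# LANGUAGE GADTs                #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE UndecidableInstances #-}
-- On GHC 8.6 or newer, also enable NoStarIsType so that the * in
-- ShapeProduct means type-level multiplication rather than the kind Type.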

import           Data.ByteString (ByteString)
import           Data.Constraint (Constraint)
import           Data.Int (Int64, Int8, Int16)
import           Data.Maybe (fromJust)
import           Data.Proxy (Proxy(..))
import qualified Data.Vector as VN
import           Data.Vector.Sized (Vector(..), toList, fromList)
import           Data.Word (Word8)
import           GHC.TypeLits (Nat, KnownNat, natVal)
import           GHC.TypeLits

import           TensorFlow.Core
import           TensorFlow.Core (Shape(..), TensorType, Tensor, Build)
import           TensorFlow.Ops (constant, add, matMul)
import           TensorFlow.Session (runSession, run)
James Bowen

Deeper Still: Convolutional Neural Networks

Two weeks ago, we began our machine learning study in earnest by constructing a full neural network. But this network was still quite simple by deep learning standards. In this article, we’re going to tackle a much more difficult problem: image recognition. Of course, we’ll still be using a well-known data set with well-known results, so this is only the tip of the iceberg. We'll be using the MNIST data set, which labels images of handwritten digits with the numbers 0-9. This problem is so well-known that the folks at Tensor Flow refer to it as the “Hello World” of machine learning.

We’ll start this problem by using a very similar approach to what we used with the Iris data set. We’ll make a fully-connected neural network with two layers, and then use the “Adam” optimizer. This will give us some decent results by our beginner standards. But MNIST is a well-known problem with a very large data set, so we’re going to hold ourselves to a higher standard of accuracy this time. This will force us to use some more advanced techniques. But to start with, let’s examine what we need to change to adapt our Iris model to the MNIST problem. As with the last couple of weeks, the code for all this is on GitHub if you want to follow along.

Re-use and Recycle!

Generally, we can re-use most of the code we had with Iris, which is good news! We still have to make a few adjustments here and there though. First, we’ll use some different constants. We’ll use mnistFeatures in place of irisFeatures, and mnistLabels instead of irisLabels. We’ll also bump up the size of our hidden layer and the number of samples we’ll draw on each iteration:

mnistFeatures :: Int64
mnistFeatures = 784

mnistLabels :: Int64
mnistLabels = 10

numHiddenUnits :: Int64
numHiddenUnits = 1024

sampleSize :: Int
sampleSize = 1000

We’ll also change our model to use Word8 as the label type instead of Int64.

data Model = Model
 { train :: TensorData Float
         -> TensorData Word8 -- Used to be Int64
         -> Session ()
 , errorRate :: TensorData Float
             -> TensorData Word8 -- Used to be Int64
             -> SummaryTensor
             -> Session (Float, ByteString)
 }

Now we have to change how we get our input data. Our data isn’t in CSV format this time. We’ll use helper functions from the Tensor Flow library to extract the images and labels:

import TensorFlow.Examples.MNIST.Parse (readMNISTSamples, readMNISTLabels)
…
runDigits :: FilePath -> FilePath -> FilePath -> FilePath -> IO ()
runDigits trainImageFile trainLabelFile testImageFile testLabelFile = 
 withEventWriter eventsDir $ \eventWriter -> runSession $ do

   -- trainingImages, testImages :: [Vector Word8]
   trainingImages <- liftIO $ readMNISTSamples trainImageFile
   testImages <- liftIO $ readMNISTSamples testImageFile

   -- trainingLabels, testLabels :: [Word8]
   trainingLabels <- liftIO $ readMNISTLabels trainLabelFile
   testLabels <- liftIO $ readMNISTLabels testLabelFile

   -- trainingRecords, testRecords :: Vector (Vector Word8, Word8)
   let trainingRecords = fromList $ zip trainingImages trainingLabels
   let testRecords = fromList $ zip testImages testLabels
   ...

Our “input” type consists of vectors of Word8 elements. These represent the intensity of various pixels. Our “output” type is Word8, referring to the actual labels (0-9). We read the images and labels from separate files. Then we zip them together to pass to our processing functions. We’ll have to make a few changes to these processing functions for this data set. First, we have to generalize the type of our randomization function:

-- Used to be IrisRecord Specific
chooseRandomRecords :: Vector a -> IO (Vector a)

Next we have to write a new encoding function that will put our data into the TensorData format. This looks like our old version, except dealing with the new tuple type instead of the IrisRecord.

convertDigitRecordsToTensorData 
 :: Vector (Vector Word8, Word8)
 -> (TensorData Float, TensorData Word8)
convertDigitRecordsToTensorData records = (input, output)
 where
   numRecords = Data.Vector.length records 
   input = encodeTensorData [fromIntegral numRecords, mnistFeatures]
     (fromList $ concatMap recordToInputs records)
   output = encodeTensorData [fromIntegral numRecords] (snd <$> records)
   recordToInputs :: (Vector Word8, Word8) -> [Float]
   recordToInputs rec = fromIntegral <$> (toList . fst) rec
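Stripped of the TensorFlow types, this encoding is just a row-major flattening: each image's pixels are laid end to end, so the input buffer holds numRecords rows of mnistFeatures values. Here is a hypothetical list-based analogue (encodeDigitRecords is not part of the article's code), assuming every image really has featuresPerRecord pixels (784 for MNIST).

```haskell
import Data.Word (Word8)

-- Hypothetical analogue of convertDigitRecordsToTensorData on plain lists.
-- Returns (input shape, flattened row-major inputs, labels).
encodeDigitRecords :: Int -> [([Word8], Word8)] -> ([Int], [Float], [Word8])
encodeDigitRecords featuresPerRecord records =
  ( [length records, featuresPerRecord]
  -- Pixels of record 0, then record 1, and so on: one row per record.
  , concatMap (map fromIntegral . fst) records
  , map snd records
  )
```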

And then we just have to substitute our new functions and parameters in, and we’ll be able to run our digit trainer!

Current training error 89.8
Current training error 19.300001
Current training error 13.300001
Current training error 11.199999
Current training error 8.700001
Current training error 6.5999985
Current training error 6.999999
Current training error 5.199999
Current training error 4.400003
Current training error 5.000001
Current training error 2.3000002

test error 6.830001

So our test accuracy is about 93.2%. That seems like an alright number. But imagine being a post office and having 6.8% of your mail sorted into the wrong ZIP code! (This was the original use case for this data set.) So let’s see if we can do better.

Convolution and Max Pooling

Now we could train our model longer. This will tend to improve our error rate. But we can also help ourselves by making our model more complex. The fundamental flaw with what we’ve got so far is that it doesn’t account for the 2D nature of the images. This means we're losing a ton of useful information. So the first thing we'll do is treat our images as being 28x28 tensors instead of 1x784. This way, our model can pick out specific areas that are significant for the identification of the digit.

One thing we want to account for is that the digit might not be in the center of the frame. To handle this, we're going to apply convolution. With convolution, we break the image into many overlapping tiles. In our case, we’ll use a stride of 1 in every direction and a patch size of 5x5. This means we’ll center a 5x5 tile around each pixel in our image and compute a score for it. That score tells us whether this part of the image contains any important information. We can represent this score as a vector with many features.

So with 2D convolution, we'll be dealing with 4-dimensional tensors. The first dimension is the sample size. The middle two dimensions are the height and width of the image. The final dimension is the number of features in the "score" for each part of the image. Each original image starts out with a single feature per pixel: its actual intensity! Each layer of convolution then acts as a mini neural network per pixel, producing as many features as we want.

The different sliding windows correspond to scores we store in the next layer. This example uses 3x3 patches; we'll use 5x5.
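To make the sliding window concrete, here is a small pure-Haskell sketch of a single-channel convolution with stride 1 and "SAME" zero padding. It is not the TensorFlow API: conv2DSame is a hypothetical helper on plain lists. Like TensorFlow's conv2D, it computes a cross-correlation (the kernel is not flipped), and it assumes a square kernel with an odd side length, such as the 5x5 patches used here.

```haskell
-- Hypothetical helper: single-channel 2-D cross-correlation, stride 1,
-- "SAME" zero padding. Assumes a square kernel with an odd side length.
conv2DSame :: [[Double]] -> [[Double]] -> [[Double]]
conv2DSame kernel image =
  [ [ score r c | c <- [0 .. width - 1] ] | r <- [0 .. height - 1] ]
  where
    height = length image
    width = case image of
      (row : _) -> length row
      [] -> 0
    k = length kernel
    half = k `div` 2
    -- Pixels outside the image count as 0 ("SAME" zero padding).
    pixel r c
      | r < 0 || c < 0 || r >= height || c >= width = 0
      | otherwise = image !! r !! c
    -- Weighted sum of the k x k window centred on (r, c).
    score r c =
      sum [ kernel !! i !! j * pixel (r + i - half) (c + j - half)
          | i <- [0 .. k - 1], j <- [0 .. k - 1] ]
```

Because of the padding, the output has the same height and width as the input, which is why convolution alone does not shrink our 28x28 images.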

Max pooling is a form of down-sampling. After our first convolution step, we’ll have scores on the 28x28 image. We’ll use 2x2 max-pooling, meaning we divide each image into 2x2 squares. Then we’ll make a new layer that is 14x14, using only the “best” score from each 2x2 box. This makes our model more efficient while keeping the most important information.

Simple demonstration of max-pooling
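Here is the same idea as a hypothetical pure-Haskell helper on a single channel. It assumes even height and width, which holds for our 28x28 and 14x14 layers.

```haskell
-- Hypothetical helper: 2x2 max pooling with stride 2 on one channel.
-- Assumes even height and width; an odd trailing row or column is dropped.
maxPool2x2 :: [[Double]] -> [[Double]]
maxPool2x2 rows = map poolPair (pairUp rows)
  where
    -- Group rows two at a time: (row 0, row 1), (row 2, row 3), ...
    pairUp (a : b : rest) = (a, b) : pairUp rest
    pairUp _ = []
    -- Column-wise max of the two rows, then the max of each
    -- adjacent pair of columns: one value per 2x2 box.
    poolPair (top, bottom) = pairMax (zipWith max top bottom)
    pairMax (x : y : rest) = max x y : pairMax rest
    pairMax _ = []
```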

Implementing Convolutional Layers

We’ll do two rounds of convolution and max pooling. So we’ll make a function that creates a layer that performs these two steps. This will look a lot like our other neural network layer. We’ll take parameters for the size of the input and output channels of the layer, as well as the tensor itself. So our first step will be to create the weights and bias tensors using these parameters:

patchSize :: Int64
patchSize = 5

buildConvPoolLayer :: Int64 -> Int64 -> Tensor v Float -> Text
                  -> Build (Variable Float, Variable Float, Tensor Build Float)
buildConvPoolLayer inputChannels outputChannels input layerName = withNameScope layerName $ do
 weights <- truncatedNormal (vector weightsShape)
   >>= initializedVariable
 bias <- truncatedNormal (vector [outputChannels]) >>= initializedVariable
 ...
 where
   weightsShape :: [Int64]
   weightsShape = [patchSize, patchSize, inputChannels, outputChannels]

Now we’ll want to call our convolution and max pooling functions. These are still a little rough around the edges (the Haskell library is still quite young). The C versions of these functions have many optional, named attributes. For the moment there don’t seem to be any functions that take normal Haskell values for these arguments. Instead, we’ll use opAttr to set each attribute by name, and combine the setters with function composition.

where
 ...
 convStridesAttr = opAttr "strides" .~ ([1,1,1,1] :: [Int64])
 poolStridesAttr = opAttr "strides" .~ ([1,2,2,1] :: [Int64])
 poolKSizeAttr = opAttr "ksize" .~ ([1,2,2,1] :: [Int64])
 paddingAttr = opAttr "padding" .~ ("SAME" :: ByteString)
 dataFormatAttr = opAttr "data_format" .~ ("NHWC" :: ByteString)
 convAttrs = convStridesAttr . paddingAttr . dataFormatAttr
 poolAttrs = poolKSizeAttr . poolStridesAttr . paddingAttr . dataFormatAttr

The strides attribute for convolution says how far we shift the window each time; here it moves one pixel at a time in every direction. For pooling, the ksize attribute sets the size of each pooling window (2x2 here), while the strides attribute moves that window two pixels at a time, so the windows don't overlap. Now that we have our attributes, we can call the library functions conv2D' and maxPool' to produce our resulting tensor. We also throw in a call to relu between these steps.

buildConvPoolLayer :: Int64 -> Int64 -> Tensor v Float -> Text
                  -> Build (Variable Float, Variable Float, Tensor Build Float)
buildConvPoolLayer inputChannels outputChannels input layerName = withNameScope layerName $ do
 weights <- truncatedNormal (vector weightsShape)
   >>= initializedVariable
 bias <- truncatedNormal (vector [outputChannels]) >>= initializedVariable
 let conv = conv2D' convAttrs input (readValue weights)
       `add` readValue bias
 let results = maxPool' poolAttrs (relu conv)
 return (weights, bias, results)
 where
   ...
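With "SAME" padding, the output size along each spatial dimension depends only on the stride: it is the input size divided by the stride, rounded up. That rule is why the stride-1 convolution keeps 28x28 and each stride-2 pooling step halves it. The two functions below sketch the output-size formulas TensorFlow uses for its two padding modes; "VALID" is included only for contrast and is not used in this model.

```haskell
-- Output length along one spatial dimension under "SAME" padding:
-- ceil (input / stride). The window size does not matter here.
sameOutputSize :: Int -> Int -> Int
sameOutputSize input stride = (input + stride - 1) `div` stride

-- "VALID" padding (for contrast only): ceil ((input - window + 1) / stride).
validOutputSize :: Int -> Int -> Int -> Int
validOutputSize input window stride = (input - window + stride) `div` stride
```

For instance, sameOutputSize 28 2 is 14 and sameOutputSize 14 2 is 7, matching the two pooling steps, while a 5x5 "VALID" convolution would shrink 28 down to 24.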

Modifying our Model

Now we’ll make a few updates to our model and we’ll be in good shape. First, we need to reshape our input data to be 4-dimensional. Then, we’ll apply the two convolution/pooling layers:

imageDimen :: Int32
imageDimen = 28

createModel :: Build Model
createModel = do
 let batchSize :: Num a => a
     batchSize = -1 -- Allows variable sized batches (used at both Int64 and Int32 below)
 let conv1OutputChannels = 32
 let conv2OutputChannels = 64
 let denseInputSize = 7 * 7 * 64 :: Int32 -- 3136
 let numHiddenUnits = 1024

 inputs <- placeholder [batchSize, mnistFeatures]
 outputs <- placeholder [batchSize]

 let inputImage = reshape inputs (vector [batchSize, imageDimen, imageDimen, 1])

 (convWeights1, convBiases1, convResults1) <- 
   buildConvPoolLayer 1 conv1OutputChannels inputImage "convLayer1"
 (convWeights2, convBiases2, convResults2) <-
   buildConvPoolLayer conv1OutputChannels conv2OutputChannels convResults1 "convLayer2"
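The reshape call only reinterprets each flat 784-value row; it keeps the elements in the same order. In the row-major NHWC layout, element (n, row, col, channel) sits at offset ((n * 28 + row) * 28 + col) * 1 + channel in the flat buffer. The hypothetical helpers below sketch that index arithmetic on plain lists.

```haskell
-- Offset of element (n, row, col, channel) in a flat row-major NHWC buffer,
-- given the image height, width, and channel count.
nhwcOffset :: (Int, Int, Int) -> (Int, Int, Int, Int) -> Int
nhwcOffset (height, width, channels) (n, row, col, channel) =
  ((n * height + row) * width + col) * channels + channel

-- Split one flat image into rows of the given (positive) width,
-- mirroring what reshape does to a 784-value row.
toRows :: Int -> [a] -> [[a]]
toRows _ [] = []
toRows width xs = row : toRows width rest
  where
    (row, rest) = splitAt width xs
```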

Once we’re done with that, we’ll apply two fully-connected (dense) layers as we did before. Note we'll reshape our result from four dimensions back down to two:

let denseInput = reshape convResults2 (vector [batchSize, denseInputSize])
(denseWeights1, denseBiases1, denseResults1) <-
  buildNNLayer (fromIntegral denseInputSize) numHiddenUnits denseInput "denseLayer1"  
let rectifiedDenseResults = relu denseResults1
(denseWeights2, denseBiases2, denseResults2) <-
  buildNNLayer numHiddenUnits mnistLabels rectifiedDenseResults "denseLayer2"

And after that we can treat the rest of the model the same. We'll update the parameter names and add the new weights and biases to the params that the model can change.

As a review, let’s look at the dimensions of each of the intermediate tensors here. Then we can see the restrictions on the dimensions of the different operations. Each convolution step takes two four-dimensional tensors. The final dimension of argument 1 must match the third dimension of argument 2. Then the result will swap in the final dimension of argument 2. Meanwhile, pooling with a 2x2 stride size will take this resulting 4-dimensional tensor and halve each of the inner dimensions.

input: n x 784
inputImage: n x 28 x 28 x 1
convWeights1: 5 x 5 x 1 x 32
convBias1: 32
conv (first layer): n x 28 x 28 x 32
convResults1: n x 14 x 14 x 32
convWeights2: 5 x 5 x 32 x 64
convBias2: 64
conv (second layer): n x 14 x 14 x 64
convResults2: n x 7 x 7 x 64
denseInput: n x 3136
denseWeights1: 3136 x 1024
denseBias1: 1024
denseResults1: n x 1024
denseWeights2: 1024 x 10
denseBias2: 10
denseResults2: n x 10
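The rules above are mechanical enough to check with a few lines of Haskell. This hypothetical sketch traces shapes as plain lists and returns Nothing whenever two dimensions fail to line up. It halves spatial sizes with div, which matches "SAME" pooling for the even sizes in this model.

```haskell
-- Shapes are plain lists of dimensions, e.g. [n, height, width, channels].
type Dims = [Int]

-- Convolution rule from the text (stride 1, "SAME" padding): the input's
-- last dimension must match the filter's third; the result takes the
-- filter's last dimension as its channel count.
convSame :: Dims -> Dims -> Maybe Dims
convSame [n, h, w, c] [_, _, c', o]
  | c == c' = Just [n, h, w, o]
convSame _ _ = Nothing

-- 2x2 max pooling with stride 2 halves height and width.
pool2x2 :: Dims -> Maybe Dims
pool2x2 [n, h, w, c] = Just [n, h `div` 2, w `div` 2, c]
pool2x2 _ = Nothing

-- Dense layer: [n, i] times [i, o] gives [n, o].
dense :: Dims -> Dims -> Maybe Dims
dense [n, i] [i', o]
  | i == i' = Just [n, o]
dense _ _ = Nothing

-- Trace the MNIST model for a batch of n images.
traceShapes :: Int -> Maybe Dims
traceShapes n = do
  conv1 <- convSame [n, 28, 28, 1] [5, 5, 1, 32]
  pool1 <- pool2x2 conv1
  conv2 <- convSame pool1 [5, 5, 32, 64]
  pool2 <- pool2x2 conv2
  let flat = [n, product (drop 1 pool2)] -- [n, 7 * 7 * 64] = [n, 3136]
  hidden <- dense flat [3136, 1024]
  dense hidden [1024, 10]
```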

So for each input, we’ll have a score for each of the 10 possible labels. We pick the label with the greatest score as the prediction.
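Strictly speaking, the final layer produces unnormalized scores (logits). Softmax turns them into probabilities, but since softmax preserves order, the predicted label is the same either way. Here is a minimal plain-Haskell sketch of both steps for one row of scores; predictLabel and softmax are hypothetical helpers, not the TensorFlow API.

```haskell
import Data.List (maximumBy)
import Data.Ord (comparing)

-- Index of the largest score in one row (what argMax computes per row).
-- Assumes a non-empty list.
predictLabel :: [Double] -> Int
predictLabel scores = fst (maximumBy (comparing snd) (zip [0 ..] scores))

-- Numerically stable softmax: subtract the maximum before exponentiating,
-- then normalise so the results sum to 1.
softmax :: [Double] -> [Double]
softmax xs = map (/ total) exps
  where
    top = maximum xs
    exps = map (\x -> exp (x - top)) xs
    total = sum exps
```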

Results

We’ll run our model again, only this time we’ll use a smaller sample size (100 per training iteration). This allows us to train for more iterations (20000). This takes quite a while to train, but we get these results (printed every 1000 iterations).

Current training error 91.0
Current training error 6.0
Current training error 2.9999971
Current training error 2.9999971
Current training error 0.0
Current training error 0.0
Current training error 0.99999905
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0
Current training error 0.0

test error 1.1799991

Not too bad! Once it got going, we saw very few training errors, though we still ended up with a model that was a tad overfit. The Tensor Flow MNIST expert tutorial suggests using dropout, which helps diminish the effects of overfitting. But this option isn’t available to us yet in Haskell. Still, we got close to 99% accuracy, which is a success for us!

And here’s what our final graph looks like. Notice the extra layers we added for convolution:

Conclusion

So that’s it for convolutional neural networks! Our goal was to adapt our previous neural network model to recognize digits. Not only did we pick a harder problem, but we also wanted higher accuracy. We achieved this by using more advanced machine learning techniques. Convolution allowed us to use the full 2D nature of the image. It also checked for the digit no matter where it was in the image. Max pooling enabled us to make our algorithm more efficient while keeping the most important information.

If you’re itching to see what else Haskell can do with Tensor Flow, check out our Tensor Flow Guide. It’ll walk you through some of the trickier parts of getting the library working on your local machine. It will also go through the most important information you need to know about the types in this library.

If you’re new to Haskell, here are a couple more resources you can dig into before you try your hand at Tensor Flow. First, there’s our Getting Started Checklist. This will point you toward some helpful resources for learning the language. Next, you can check out our Stack mini-course so you can learn how to organize a project! If you want to use Tensor Flow in Haskell, Stack is a must!

James Bowen

Putting the Flow in Tensor Flow!

Last week we built our first neural network and used it on a real machine learning problem. We took the Iris data set and built a classifier that took in various flower measurements. It determined, with decent accuracy, the type of flower the measurements referred to.

But we’ve still only seen half the story of Tensor Flow! We’ve constructed many tensors and combined them in interesting ways. We can imagine what is going on with the “flow”, but we haven’t seen a visual representation of that yet.

We’re in luck though, thanks to the Tensor Board application. With it, we can visualize the computation graph we've created. We can also track certain values throughout our program run. In this article, we’ll take our Iris example and show how we can add Tensor Board features to it. Here's the GitHub repo with all the code so you can follow along!

Add an Event Writer

The first thing to understand about Tensor Board is that it gets its data from a source directory. While we’re running our system, we have to direct it to write events to that directory. This will allow Tensor Board to know what happened in our training run.

eventsDir :: FilePath
eventsDir = "/tmp/tensorflow/iris/logs/"

runIris :: FilePath -> FilePath -> IO ()
runIris trainingFile testingFile = withEventWriter eventsDir $ \eventWriter -> runSession $ do
...

By itself, this doesn’t write anything into that directory though! To see the consequences of this, let’s boot up Tensor Board.

Running Tensor Board

Running our executable again doesn't bring up Tensor Board. It merely logs the information that Tensor Board uses. To actually see that information, we’ll run the tensorboard command.

>> tensorboard --logdir='/tmp/tensorflow/iris/logs'
Starting TensorBoard 47 at http://0.0.0.0:6006

Then we can point our web browser at that address (port 6006). Since we haven't written anything to the log directory yet, there won’t be much for us to see other than some pretty graphics. So let’s start by logging our graph. This is actually quite easy! Remember our model? We can use the logGraph function combined with our event writer so we can see it.

model <- build createModel
logGraph eventWriter createModel

Now when we refresh Tensor Board, we’ll see our system’s graph.

What the heck is going on here?

But it’s very large and hard to follow. The node names aren't very descriptive, and it’s not clear what data is going where. Plus, we have no idea what’s going on with our error rate or anything like that. Let’s make a couple of adjustments to fix this.

Adding Summaries

So the first step is to actually specify some measurements that we’ll have Tensor Board plot for us. One node we can use is a “scalar summary”. This provides us with a summary of a particular value over the course of our training run. Let’s do this with our errorRate node. We can use the simple scalarSummary function.

errorRate_ <- render $ 1 - (reduceMean (cast correctPredictions))
scalarSummary "Error" errorRate_

The second type of summary is a histogram summary. We use this on a particular tensor to see the distribution of its values over the course of the run. Let’s do this with our second set of weights. We need to use readValue to go from a Variable to a Tensor.

(finalWeights, finalBiases, finalResults) <-
  buildNNLayer numHiddenUnits irisLabels rectifiedHiddenResults
histogramSummary "Weights" (readValue finalWeights)

So let’s run our program again. We would expect to see these new values show up under the Scalars and Histograms tabs. But they don’t. This is because we still have to write these results to our event writer, and this turns out to be a little complicated. First, before we start training, we have to create a tensor representing all our summaries.

logGraph eventWriter createModel
summaryTensor <- build mergeAllSummaries

Now if we had no placeholders, we could run this tensor whenever we wanted, and it would output the values. But our summary tensors depend on the input placeholders, which complicates the matter. So here’s what we’ll do. We’ll only write out the summaries when we check our error rate (every 100 steps). To do this, we have to change the error rate function in our model to take the summary tensor as an extra argument. We’ll also have it return the serialized summaries as a ByteString alongside the original Float.

data Model = Model
  { train :: TensorData Float
          -> TensorData Int64
          -> Session ()
  , errorRate :: TensorData Float
              -> TensorData Int64
              -> SummaryTensor
              -> Session (Float, ByteString)
  }

Within our model definition, we’ll use this extra parameter. It will run both the errorRate_ tensor AND the summary tensor together with the feeds:

return $ Model
  { train = ...
  , errorRate = \inputFeed outputFeed summaryTensor -> do
      -- Run the error and summary tensors together with the same feeds.
      (errorTensorResult, summaryTensorResult) <- runWithFeeds
        [ feed inputs inputFeed
        , feed outputs outputFeed
        ]
        (errorRate_, summaryTensor)
      return (unScalar errorTensorResult, unScalar summaryTensorResult)
  }

Now we need to modify our calls to errorRate below. We’ll pass the summary tensor as an argument, and get the bytes as output. We’ll write it to our event writer (after decoding), and then we’ll be done!

  -- Training
  forM_ ([0..1000] :: [Int]) $ \i -> do
    trainingSample <- liftIO $ chooseRandomRecords trainingRecords
    let (trainingInputs, trainingOutputs) = convertRecordsToTensorData trainingSample
    (train model) trainingInputs trainingOutputs
    when (i `mod` 100 == 0) $ do
      (err, summaryBytes) <- (errorRate model) trainingInputs trainingOutputs summaryTensor
      let summary = decodeMessageOrDie summaryBytes
      liftIO $ putStrLn $ "Current training error " ++ show (err * 100)
      logSummary eventWriter (fromIntegral i) summary

  liftIO $ putStrLn ""

  -- Testing
  let (testingInputs, testingOutputs) = convertRecordsToTensorData testRecords
  (testingError, _) <- (errorRate model) testingInputs testingOutputs summaryTensor
  liftIO $ putStrLn $ "test error " ++ show (testingError * 100)

Now we can see what our summaries look like by running Tensor Board again!

Scalar Summary of our Error Rate

Histogram summary of our final weights.

Annotating our Graph

Now let’s look back to our graph. It’s still a bit confusing. We can clean it up a lot by creating “name scopes”. A name scope is part of the graph that we set aside under a single name. When Tensor Board generates our graph, it will create one big block for the scope. We can then zoom in and examine the individual nodes if we want.
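Tensor Board builds this grouping from node names: an operation created inside a name scope gets a name prefixed with the scope and a slash, so a node created in a layer1 scope has a name starting with layer1/. As a toy model (not the TensorFlow API, and with made-up node names in the usage below), grouping names by the part before the first slash gives one block per scope:

```haskell
import qualified Data.Map.Strict as Map

-- Toy model, not the TensorFlow API: group fully qualified node names by
-- their top-level scope, the way Tensor Board collapses a scope into a block.
-- flip (++) keeps names in their original order within each scope.
groupByScope :: [String] -> Map.Map String [String]
groupByScope names =
  Map.fromListWith (flip (++)) [ (takeWhile (/= '/') name, [name]) | name <- names ]
```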

We’ll make three different scopes. First, we’ll make a scope for each of the hidden layers of our neural network. This is quite easy, since we already have a function for creating these. All we have to do is make the function take an extra parameter for the name of the scope we want. Then we wrap the whole function within the withNameScope function.

buildNNLayer :: Int64 -> Int64 -> Tensor v Float -> Text
             -> Build (Variable Float, Variable Float, Tensor Build Float)
buildNNLayer inputSize outputSize input layerName = withNameScope layerName $ do
  weights <- truncatedNormal (vector [inputSize, outputSize]) >>= initializedVariable
  bias <- truncatedNormal (vector [outputSize]) >>= initializedVariable
  let results = (input `matMul` readValue weights) `add` readValue bias
  return (weights, bias, results)

We supply our name further down in the code:

(hiddenWeights, hiddenBiases, hiddenResults) <- 
  buildNNLayer irisFeatures numHiddenUnits inputs "layer1"
let rectifiedHiddenResults = relu hiddenResults
(finalWeights, finalBiases, finalResults) <-
  buildNNLayer numHiddenUnits irisLabels rectifiedHiddenResults "layer2"

Now we’ll add a scope around all our error calculations. First, we combine these into an action wrapped in withNameScope. Then, observing that we need the errorRate_ and train_ steps, we return those from the block. That’s it!

(errorRate_, train_) <- withNameScope "error_calculation" $ do
    actualOutput <- render $ cast $ argMax finalResults (scalar (1 :: Int64))
    let correctPredictions = equal actualOutput outputs
    er <- render $ 1 - (reduceMean (cast correctPredictions))
    scalarSummary "Error" er

    let outputVectors = oneHot outputs (fromIntegral irisLabels) 1 0
    let loss = reduceMean $ fst $ softmaxCrossEntropyWithLogits finalResults outputVectors
    let params = [hiddenWeights, hiddenBiases, finalWeights, finalBiases]
    tr <- minimizeWith adam loss params
    return (er, tr)

Now when we look at our graph, we see that it’s divided into three parts: our two layers, and our error calculation. All the information flows among these three parts (as well as the "Adam" optimizer portion).

Much Better

Conclusion

By default, Tensor Board graphs can look a little messy. But by adding a little more information to the nodes and using scopes, you can paint a much clearer picture. You can see how the data flows from one end of the application to the other. We can also use summaries to track important information about our graph. We’ll use this most often for the loss function or error rate. Hopefully, we'll see it decline over time.

Next week we’ll add some more complexity to our neural networks. We'll see new tensors for convolution and max pooling. This will allow us to solve the more difficult MNIST digit recognition problem. Stay tuned!

If you’re itching to try out some Tensor Board functionality for yourself, check out our in-depth Tensor Flow guide. It goes into more detail about the practical aspects of using this library. If you want to get the Haskell Tensor Flow library running on your local machine, check it out! Trust me, it's a little complicated, unless you're a Stack wizard already!

And if this is your first exposure to Haskell, try it out! Take a look at our guide to getting started with the language!

James Bowen

Digging in Deep: Solving a Real Problem with Haskell Tensor Flow

Last week we got acquainted with the core concepts of Tensor Flow. We learned about the differences between constants, placeholders, and variable tensors. Both the Haskell and Python bindings have functions to represent these. The Python version was a bit simpler though. Once we had our tensors, we wrote a program that “learned” a simple linear equation.

This week, we’re going to solve an actual machine learning problem. We’re going to use the Iris data set, which contains measurements of different Iris flowers. Each flower belongs to one of three species. Our program will "learn" a function that chooses the species from the measurements. This function will involve a fully-connected neural network.

Formatting our Input

The first step in pretty much any machine learning problem is data processing. After all, our data doesn’t magically get resolved into Haskell data types. Luckily, Cassava is a great library to help us out. The Iris data set consists of data in .csv files that each have a header line and then a series of records. They look a bit like this:

120,4,setosa,versicolor,virginica
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
4.9,2.5,4.5,1.7,2
4.9,3.1,1.5,0.1,0
...

Each line contains one record. A record has four flower measurements and a final label. In this case, we have three types of flowers we are trying to classify: Iris Setosa, Iris Versicolor, and Iris Virginica. So the last column contains the numbers 0, 1, and 2, corresponding to these respective classes.
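Cassava will do the parsing for us, but it can help to see exactly what one data line holds. Here is a minimal dependency-free sketch, assuming every data row has four decimal measurements followed by an integer label and no quoted fields. The `splitOn` and `parseIrisLine` helpers are hypothetical, and the record type is a local copy so the sketch stands on its own.

```haskell
import Data.Int (Int64)
import Text.Read (readMaybe)

-- Local copy of the record type, so this sketch is self-contained.
data IrisRecord = IrisRecord
  { field1 :: Float
  , field2 :: Float
  , field3 :: Float
  , field4 :: Float
  , label  :: Int64
  } deriving (Show, Eq)

-- Split a line on a separator (no quoting support; fine for numeric rows).
splitOn :: Char -> String -> [String]
splitOn sep s = case break (== sep) s of
  (chunk, [])       -> [chunk]
  (chunk, _ : rest) -> chunk : splitOn sep rest

-- Parse one data row such as "6.4,2.8,5.6,2.2,2".
-- Malformed rows, like the "120,4,setosa,..." header, give Nothing.
parseIrisLine :: String -> Maybe IrisRecord
parseIrisLine line = case splitOn ',' line of
  [a, b, c, d, l] ->
    IrisRecord <$> readMaybe a <*> readMaybe b <*> readMaybe c
               <*> readMaybe d <*> readMaybe l
  _ -> Nothing
```

The header line fails because "setosa" cannot be read as a Float, which mirrors why we tell Cassava that the file HasHeader.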

Let's create a data type representing each record. Then we can parse the file line-by-line. Our IrisRecord type will contain the feature data and the resulting label. This type will act as a bridge between our raw data and the tensor format we’ll need to run our learning algorithm. We’ll derive the “Generic” typeclass for our record type, and use this to get FromRecord. Once our type has an instance for FromRecord, we can parse it with ease. As a note, throughout this article, I’ll be omitting the imports section from the code samples. I’ve included a full list of imports from these files as an appendix at the bottom. We'll also be using the OverloadedLists extension throughout.

{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedLists #-}

...

data IrisRecord = IrisRecord
  { field1 :: Float
  , field2 :: Float
  , field3 :: Float
  , field4 :: Float
  , label  :: Int64
  }
  deriving (Generic)

instance FromRecord IrisRecord

Now that we have our type, we’ll write a function, readIrisFromFile, that will read our data in from a CSV file.

readIrisFromFile :: FilePath -> IO (Vector IrisRecord)
readIrisFromFile fp = do
  contents <- readFile fp
  let contentsAsBs = pack contents
  let results = decode HasHeader contentsAsBs :: Either String (Vector IrisRecord)
  case results of
    Left err -> error err
    Right records -> return records

We won’t want to always feed our entire data set into our training system. So given a whole slew of these items, we should be able to pick out a random sample.

sampleSize :: Int
sampleSize = 10

chooseRandomRecords :: Vector IrisRecord -> IO (Vector IrisRecord)
chooseRandomRecords records = do
  let numRecords = Data.Vector.length records
  chosenIndices <- take sampleSize <$> shuffleM [0..(numRecords - 1)]
  return $ fromList $ map (records !) chosenIndices

Once we’ve selected our vector of records to use for each run, we’re still not done. We need to take these records and transform them into the TensorData that we’ll feed into our algorithm. We create items of TensorData by feeding in a shape and then a 1-dimensional vector of values. First, we need to know the shapes of our input and output. Both of these depend on the number of rows in the sample. The “input” will have a column for each of the four features in our set. The output meanwhile will have a single column for the label values.

irisFeatures :: Int64
irisFeatures = 4

irisLabels :: Int64
irisLabels = 3

convertRecordsToTensorData :: Vector IrisRecord -> (TensorData Float, TensorData Int64)
convertRecordsToTensorData records = (input, output)
  where
    numRecords = Data.Vector.length records 
    input = encodeTensorData [fromIntegral numRecords, irisFeatures] (undefined)
    output = encodeTensorData [fromIntegral numRecords] (undefined)

Now all we need to do is take the various records and turn them into one-dimensional vectors to encode. Here’s the final function:

convertRecordsToTensorData :: Vector IrisRecord -> (TensorData Float, TensorData Int64)
convertRecordsToTensorData records = (input, output)
  where
    numRecords = Data.Vector.length records 
    input = encodeTensorData [fromIntegral numRecords, irisFeatures]
      (fromList $ concatMap recordToInputs records)
    output = encodeTensorData [fromIntegral numRecords] (label <$> records)
    recordToInputs :: IrisRecord -> [Float]
    recordToInputs rec = [field1 rec, field2 rec, field3 rec, field4 rec]
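A shape plus a flat vector is enough because the values are laid out row by row (row-major order, which is what Tensor Flow uses). This pure-Haskell sketch uses plain lists and hypothetical helper names (`flattenRows`, `elementAt`) to show that element (i, j) of a matrix with 4 columns ends up at position i * 4 + j of the flat vector, just like the output of concatMap recordToInputs.

```haskell
-- Flatten rows of equal width into (shape, values), row by row.
-- Ragged rows give Nothing, since a tensor needs one fixed shape.
-- An empty list of rows is treated as shape [0, 0].
flattenRows :: [[Float]] -> Maybe ([Int], [Float])
flattenRows [] = Just ([0, 0], [])
flattenRows rows@(r : _)
  | all ((== width) . length) rows = Just ([length rows, width], concat rows)
  | otherwise                      = Nothing
  where
    width = length r

-- Look up element (i, j) of a rank-2 tensor stored in row-major order.
elementAt :: ([Int], [Float]) -> Int -> Int -> Float
elementAt ([_, cols], values) i j = values !! (i * cols + j)
elementAt _ _ _ = error "elementAt: expected a rank-2 shape"
```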

Neural Network Basics

Now that we’ve got that out of the way, we can start writing our model. Remember, we want to perform two different actions with our model. First, we want to be able to take our training input and train the weights. Second, we want to be able to pass a test data set and determine the error rate. We can represent these two different functions as a single Model object. Remember the Session monad, where we run all our Tensor Flow activities. The training will run an action that alters the variables but returns nothing. The error rate calculation will return us a float value.

data Model = Model
  { train :: TensorData Float -- Training input
          -> TensorData Int64 -- Training output
          -> Session ()
  , errorRate :: TensorData Float -- Test input
              -> TensorData Int64 -- Test output
              -> Session Float
  }

Now we’re going to build a fully-connected neural network. We’ll have 4 input units (1 for each of the different features), and then we’ll have 3 output units (1 for each of the classes we’re trying to represent). In the middle, we’ll use a hidden layer consisting of 10 units. This means we’ll need two sets of weights and biases. We’ll write a function that, when given dimensions, will give us the variable tensors for each layer. We want the weight and bias tensors, plus the result tensor of the layer.

buildNNLayer :: Int64 -> Int64 -> Tensor v Float
             -> Build (Variable Float, Variable Float, Tensor Build Float)
buildNNLayer inputSize outputSize input = do
  weights <- truncatedNormal (vector [inputSize, outputSize]) >>= initializedVariable
  bias <- truncatedNormal (vector [outputSize]) >>= initializedVariable
  let results = (input `matMul` readValue weights) `add` readValue bias
  return (weights, bias, results)

We do this in the Build monad, which allows us to construct variables, among other things. We’ll use a truncatedNormal distribution for all our variables to keep things simple. We specify the size of each variable in a vector tensor, and then initialize them. Then we’ll create the resulting tensor by multiplying the input by our weights and adding the bias.
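To see the arithmetic a layer performs, here is a pure-Haskell sketch using plain lists instead of tensors. It computes results = input × weights + bias for each input row and applies the rectifier element-wise. The names `denseLayer`, `reluM`, and `layerParams` are hypothetical. The parameter count also shows the size of our model: 4 × 10 + 10 values in the hidden layer and 10 × 3 + 3 in the output layer, 83 trainable values in total.

```haskell
import Data.List (transpose)

type Matrix = [[Float]]  -- a list of rows

-- One fully-connected layer: each input row is multiplied by the
-- weight matrix (inputSize x outputSize), then the bias is added.
denseLayer :: Matrix -> [Float] -> Matrix -> Matrix
denseLayer weights bias inputs =
  [ zipWith (+) bias [ sum (zipWith (*) row col) | col <- transpose weights ]
  | row <- inputs ]

-- Rectifier activation, applied element-wise.
reluM :: Matrix -> Matrix
reluM = map (map (max 0))

-- Trainable values in one layer: the weights plus one bias per output unit.
layerParams :: Int -> Int -> Int
layerParams inputSize outputSize = inputSize * outputSize + outputSize
```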

Constructing our Model

Now we’ll start building our Model object, again within the Build monad. We begin by specifying our input and output placeholders, as well as the number of hidden units. We’ll also use a batchSize of -1 to account for the fact that we want a variable number of input samples.

irisFeatures :: Int64
irisFeatures = 4

irisLabels :: Int64
irisLabels = 3
-- ^^ From above

createModel :: Build Model
createModel = do
  let batchSize = -1 -- Allows variable sized batches
  let numHiddenUnits = 10
  inputs <- placeholder [batchSize, irisFeatures]
  outputs <- placeholder [batchSize]

Then we’ll get the nodes for the two layers of variables, as well as their results. Between the layers, we’ll add a “rectifier” activation function relu:

(hiddenWeights, hiddenBiases, hiddenResults) <- 
  buildNNLayer irisFeatures numHiddenUnits inputs
let rectifiedHiddenResults = relu hiddenResults
(finalWeights, finalBiases, finalResults) <-
  buildNNLayer numHiddenUnits irisLabels rectifiedHiddenResults

Now we have to get the inferred classes of each output. This means calling argMax to take the class with the highest probability. We’ll also cast the vector and then render it. These are some Haskell-Tensor-Flow specific terms for getting tensors to the right type. Next, we compare that against our output placeholders to see how many we got correct. Then we’ll make a node for calculating the error rate for this run.

actualOutput <- render $ cast $ argMax finalResults (scalar (1 :: Int64))
let correctPredictions = equal actualOutput outputs
errorRate_ <- render $ 1 - (reduceMean (cast correctPredictions))
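The error-rate node boils down to "pick the highest-scoring class for each row, compare it to the label, and take one minus the fraction correct." Here is a list-based sketch of that calculation; `argMaxRow` and `errorRateOf` are hypothetical names, and this version keeps the first index when two scores tie.

```haskell
-- Index of the largest score in one row of outputs (the predicted class).
-- Ties keep the first index.
argMaxRow :: [Float] -> Int
argMaxRow [] = error "argMaxRow: empty row"
argMaxRow (s : ss) = go 0 s 1 ss
  where
    go best _ _ [] = best
    go best bestVal i (x : xs)
      | x > bestVal = go i x (i + 1) xs
      | otherwise   = go best bestVal (i + 1) xs

-- 1 - mean(correct), mirroring the errorRate_ node above.
errorRateOf :: [[Float]] -> [Int] -> Float
errorRateOf logits labels =
  1 - fromIntegral correct / fromIntegral (length labels)
  where
    correct = length (filter id (zipWith (==) (map argMaxRow logits) labels))
```

With four rows where three predictions match their labels, the error rate comes out as 0.25.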

Now we have to actually do the work of training. First, we’ll make oneHot vectors for our expected outputs. This means converting the label 0 into the vector [1,0,0], and so on. We’ll compare these values against our results (before we took the max), and this gives us our loss function. Then we will make a list of the parameters we want to train. The adam optimizer will minimize our loss function while modifying the params.

let outputVectors = oneHot outputs (fromIntegral irisLabels) 1 0
let loss = reduceMean $ fst $ softmaxCrossEntropyWithLogits finalResults outputVectors
let params = [hiddenWeights, hiddenBiases, finalWeights, finalBiases]
train_ <- minimizeWith adam loss params
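Here is what the loss computes for a single example, written as a dependency-free sketch with hypothetical names (`oneHot`, `softmax`, `softmaxCrossEntropy`). The library's softmaxCrossEntropyWithLogits returns a pair of tensors; as we understand it, the first is the per-example loss, which is why the code above takes fst and then averages with reduceMean. The loss is -Σ target · log(softmax(logits)), so it is small when the correct class gets most of the probability.

```haskell
-- One-hot encode a label: oneHot 3 0 == [1, 0, 0].
oneHot :: Int -> Int -> [Double]
oneHot depth lbl = [ if i == lbl then 1 else 0 | i <- [0 .. depth - 1] ]

-- Softmax, shifting by the maximum score for numerical stability.
softmax :: [Double] -> [Double]
softmax xs = map (/ total) exps
  where
    m     = maximum xs
    exps  = map (\x -> exp (x - m)) xs
    total = sum exps

-- Cross-entropy between a one-hot target and the softmax of the logits.
softmaxCrossEntropy :: [Double] -> [Double] -> Double
softmaxCrossEntropy logits target =
  negate (sum (zipWith (\t p -> t * log p) target (softmax logits)))
```

With all-zero logits every class gets probability 1/3, so the loss is log 3 (about 1.0986) whatever the label.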

Now we’ve got our errorRate_ and train_ nodes ready. There's one last step here. We have to plug in for the placeholder values and create functions that will take in the tensor data. Remember the feed pattern from last week? We use it again here. Finally, our model is complete!

return $ Model
  { train = \inputFeed outputFeed -> 
      runWithFeeds
        [ feed inputs inputFeed
        , feed outputs outputFeed
        ]
        train_
  , errorRate = \inputFeed outputFeed -> unScalar <$>
      runWithFeeds
        [ feed inputs inputFeed
        , feed outputs outputFeed
        ]
        errorRate_
  }

Tying it all together

Now we’ll write our main function that will run the session. It will have three stages. In the preparation stage, we’ll load our data, and use the build function to get our model. Then we’ll train our model for 1000 steps by choosing samples and converting our records to data. Every 100 steps, we'll print the output. Finally, we’ll determine the resulting error ratio by using the test data.

runIris :: FilePath -> FilePath -> IO ()
runIris trainingFile testingFile = runSession $ do
  -- Preparation
  trainingRecords <- liftIO $ readIrisFromFile trainingFile
  testRecords <- liftIO $ readIrisFromFile testingFile
  model <- build createModel

  -- Training
  forM_ ([0..1000] :: [Int]) $ \i -> do
    trainingSample <- liftIO $ chooseRandomRecords trainingRecords
    let (trainingInputs, trainingOutputs) = convertRecordsToTensorData trainingSample
    (train model) trainingInputs trainingOutputs
    when (i `mod` 100 == 0) $ do
      err <- (errorRate model) trainingInputs trainingOutputs
      liftIO $ putStrLn $ "Current training error " ++ show (err * 100)

  liftIO $ putStrLn ""

  -- Testing
  let (testingInputs, testingOutputs) = convertRecordsToTensorData testRecords
  testingError <- (errorRate model) testingInputs testingOutputs
  liftIO $ putStrLn $ "test error " ++ show (testingError * 100)

  return ()

Results

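So when we actually run all this code, we get output like the following. The “Current training error” lines measure the error on the current 10-record training sample, and the final line measures it on the test set.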

Current training error 60.000004
Current training error 30.000002
Current training error 39.999996
Current training error 19.999998
Current training error 10.000002
Current training error 10.000002
Current training error 19.999998
Current training error 19.999998
Current training error 10.000002
Current training error 10.000002
Current training error 0.0

test error 3.333336

Our test set had 30 records, so this means we got 29/30 correct this time around. Results change from run to run, though (I obviously used the best results I found). Since our sample size is so small, the results vary a lot (sometimes the error rate is as high as 40%). Generally we’ll want to train longer on a larger data set so that we get more consistent results, but this is a good start.
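A little arithmetic explains why the printed numbers look the way they do. Each training sample holds sampleSize = 10 records, so the training error can only move in steps of 10% (trailing digits like 60.000004 are just Float rounding). The 30-record test set moves in steps of 1/30:

$$
\text{test error} = 1 - \frac{29}{30} = \frac{1}{30} \approx 0.0333 = 3.33\%
$$

which matches the printed 3.333336 once it is multiplied by 100.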

Conclusion

In this article we went over the basics of making a neural network using the Haskell Tensor Flow library. We made a fully-connected neural network and fed in real data we parsed using the Cassava library. This network was able to learn a function to classify flowers from the Iris data set. Considering the small amount of data, we got some good results.

Come back next week, where we’ll see how we can add some more summary information to our tensor flow graph. We’ll use the tensor board application to view our graph in a visual format.

For more details on installing the Haskell Tensor Flow system, check out our In-Depth Tensor Flow Tutorial. It should walk you through the important steps in running the code on your own machine.

Perhaps you’ve never tried Haskell before at all, and want to see what it’s like. Maybe I’ve convinced you that Haskell is in fact the future of AI. In that case, you should check out our Getting Started Checklist for some tools on starting with the language.

Appendix: All Imports

Documentation for Haskell Tensor Flow is still a major work in progress. So I want to make sure I explicitly list the modules you need to import for all the different functions we used here.

import Control.Monad (forM_, when)
import Control.Monad.IO.Class (liftIO)
import Data.ByteString.Lazy.Char8 (pack)
import Data.Csv (FromRecord, decode, HasHeader(..))
import Data.Int (Int64)
import Data.Vector (Vector, length, fromList, (!))
import GHC.Generics (Generic)
import System.Random.Shuffle (shuffleM)

import TensorFlow.Core (TensorData, Session, Build, render, runWithFeeds, feed, unScalar, build,
                        Tensor, encodeTensorData)
import TensorFlow.Minimize (minimizeWith, adam)
import TensorFlow.Ops (placeholder, truncatedNormal, add, matMul, relu,
                      argMax, scalar, cast, oneHot, reduceMean, softmaxCrossEntropyWithLogits, 
                      equal, vector)
import TensorFlow.Session (runSession)
import TensorFlow.Variable (readValue, initializedVariable, Variable)
James Bowen

Starting out with Haskell Tensor Flow

Last week we discussed the burgeoning growth of AI systems. We saw several examples of how those systems are impacting our lives more and more. I made the case that we ought to focus more on reliability when making architecture choices. After all, people’s lives might be at stake when we write code now. Naturally, I suggested Haskell as a prime candidate for developing reliable AI systems.

So now we’ll actually write some Haskell machine learning code. We'll focus on the Tensor Flow bindings library. I first got familiar with this library back at BayHac in April. I’ve spent the last couple months learning both Tensor Flow as a whole and the Haskell library. In this first article, we’ll go over the basic concepts of Tensor Flow. We'll see how they’re implemented in Python (the most common language for TF). We'll then translate these concepts to Haskell.

Note this series will not be a general introduction to the concept of machine learning. There is a fantastic series on Medium about that called Machine Learning is Fun! If you’re interested in learning the basic concepts, I highly recommend you check out part 1 of that series. Many of the ideas in my own article series will be a lot clearer with that background.

Tensors

Tensor Flow is a great name because it breaks the library down into the two essential concepts. First up are tensors. These are the primary vehicle of data representation in Tensor Flow. Low-dimensional tensors are actually quite intuitive. But there comes a point when you can’t really visualize what’s going on, so you have to let the theoretical idea guide you.

In the world of big data, we represent everything numerically. And when you have a group of numbers, a programmer’s natural instinct is to put those in an array.

[1.0, 2.0, 3.0, 6.7]

Now what do you do if you have a lot of different arrays of the same size and you want to associate them together? You make a 2-dimensional array (an array of arrays), which we also refer to as a matrix.

[[1.0, 2.0, 3.0, 6.7],
[5.0, 10.0, 3.0, 12.9],
[6.0, 12.0, 15.0, 13.6],
[7.0, 22.0, 8.0, 5.3]]

Most programmers are pretty familiar with these concepts. Tensors take this idea and keep extending it. What happens when you have a lot of matrices of the same size? You can group them together as an array of matrices. We could call this a three-dimensional matrix. But “tensor” is the term we’ll use for this data representation in all dimensions.

Every tensor has a degree (also called its rank). We could start with a single number. This is a tensor of degree 0. A normal array is a tensor of degree 1, and a matrix is a tensor of degree 2. Our last example would be a tensor of degree 3. And you can keep stacking these on top of each other, ad infinitum.

Every tensor has a shape. The shape is an array representing the dimensions of the tensor. The length of this array will be the degree of the tensor. So a number will have the empty list as its shape. An array will have a list of length 1 containing the length of the array. A matrix will have a list of length 2 containing its number of rows and columns. And so on. There are a few different ways we can represent tensors in code, but we'll get to that in a bit.
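The degree and shape rules can be captured in a few lines of Haskell. This is a toy nested representation (the `ToyTensor` type and the `shapeOf`, `degreeOf`, `vec`, and `mat` helpers are hypothetical, not the library's types): a scalar has shape [], and a list of sub-tensors has its own length prepended to their common shape. Sub-tensors with different shapes are rejected, since a "ragged" array isn't a tensor.

```haskell
-- A toy tensor: a single number, or a list of equally shaped sub-tensors.
data ToyTensor = Scalar Double | Nested [ToyTensor]

-- The shape lists the size of each dimension; Nothing if ragged.
shapeOf :: ToyTensor -> Maybe [Int]
shapeOf (Scalar _) = Just []
shapeOf (Nested ts) = do
  shapes <- mapM shapeOf ts
  case shapes of
    [] -> Just [0]
    (s : rest)
      | all (== s) rest -> Just (length ts : s)
      | otherwise       -> Nothing

-- The degree is just the length of the shape.
degreeOf :: ToyTensor -> Maybe Int
degreeOf = fmap length . shapeOf

-- Helpers for building vectors and matrices from plain lists.
vec :: [Double] -> ToyTensor
vec = Nested . map Scalar

mat :: [[Double]] -> ToyTensor
mat = Nested . map vec
```

For the 4 × 4 matrix above, shapeOf gives Just [4,4] and degreeOf gives Just 2; a single number gives Just [] and degree 0.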

Go with the Flow

The second important concept to understand is how Tensor Flow performs computations. Machine learning generally involves simple math operations. A lot of simple math operations. Since the scale is so large, we need to perform these operations as fast as possible. We need to use software and hardware that is optimized for this specific task. This necessitates having a low-level code representation of what’s going on. This is easier to achieve in a language like C than in Haskell or Python.

We could have the bulk of our code in Haskell, but perform the math in C using a Foreign Function Interface. But each call across such an interface carries some overhead, and with a huge number of small operations, that overhead is likely to negate most of the gains we get from using C.

Tensor Flow’s solution to this problem is that we first build up a graph describing all our computations. Then once we have described that, we “run” our graph using a “session”. This way the whole graph crosses the language boundary at once, rather than one operation at a time, so the overhead is much lower.

If this sounds familiar, it's because this is how actions tend to work in Haskell (in some sense). We can, for instance, describe an IO action. And this action isn’t a series of commands that we execute the moment they show up in the code. Rather, the action is a description of the operations that our program will perform at some point. It’s also similar to the concept of Effectful programming. We’ll explore that topic in the future on this blog.

So what does our computation graph look like? Well, each tensor is a node. Then we can make other nodes for "operations" that take tensors as input. For instance, we can “add” two tensors together, and this is another node. We’ll see in our example how we build up the computational graph, and then run it.
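The "describe first, run later" idea can be modelled in a few lines of Haskell. This is a toy sketch, not how Tensor Flow is implemented: a `Node` value only describes a computation, and nothing is computed until a hypothetical `runGraph` function walks the graph. Placeholder values come from a list of feeds, which is the role that feeds will play in the real code later on.

```haskell
-- A graph node only *describes* a computation.
data Node
  = Constant Double
  | Placeholder String  -- value supplied when the graph is run
  | Add Node Node
  | Mul Node Node

-- Evaluate a graph, looking up placeholder values in the feeds.
-- Returns Left if some placeholder was not fed.
runGraph :: [(String, Double)] -> Node -> Either String Double
runGraph _ (Constant x) = Right x
runGraph feeds (Placeholder name) =
  maybe (Left ("no feed for placeholder " ++ name)) Right (lookup name feeds)
runGraph feeds (Add a b) = (+) <$> runGraph feeds a <*> runGraph feeds b
runGraph feeds (Mul a b) = (*) <$> runGraph feeds a <*> runGraph feeds b
```

Building `Add (Placeholder "a") (Placeholder "b")` does no arithmetic at all; only `runGraph [("a", 3), ("b", 4.5)]` on it produces `Right 7.5`.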

One of the awesome features of Tensor Flow is the Tensor Board application. It allows you to visualize your graph of computations. We’ll see how to do this later in the series.

Coding Tensors

So at this point we should start examining how we actually create tensors in our code. We’ll start by looking at how we do this in Python, since the concepts are a little easier to understand that way. There are three types of tensors we’ll consider. The first are “constants”. These represent a set of values that do not change. We can use these values throughout our model training process, and they'll be the same each time. Since we define the values for the tensor up front, there’s no need to give any size arguments. But we will specify the datatype that we’ll use for them.

import tensorflow as tf

node1 = tf.constant(3.0, dtype=tf.float32)
node2 = tf.constant(4.0, dtype=tf.float32)

Now what can we actually do with these tensors? Well for a quick sample, let’s try adding them. This creates a new node in our graph that represents the addition of these two tensors. Then we can “run” that addition node to see the result. To encapsulate all our information, we’ll create a “Session”:

import tensorflow as tf

node1 = tf.constant(3.0, dtype=tf.float32)
node2 = tf.constant(4.0, dtype=tf.float32)
additionNode = tf.add(node1, node2)

sess = tf.Session()
result = sess.run(additionNode)
print(result)

"""
Output:
7.0
"""

Next up are placeholders. These are values that we change each run. Generally, we will use these for the inputs to our model. By using placeholders, we'll be able to change the input and train on different values each time. When we “run” a session, we need to assign values to each of these nodes.

We don’t know the values that will go into a placeholder, but we still assign the type of data at construction. We can also assign a size if we like. So here’s a quick snippet showing how we initialize placeholders. Then we can assign different values with each run of the application. Even though our placeholder tensors don’t have values, we can still add them just as we could with constant tensors.

node1 = tf.placeholder(tf.float32)
node2 = tf.placeholder(tf.float32)
adderNode = tf.add(node1, node2)

sess = tf.Session()
result1 = sess.run(adderNode, {node1: 3, node2: 4.5 })
result2 = sess.run(adderNode, {node1: 2.7, node2: 8.9 })
print(result1)
print(result2)

"""
Output:
7.5
11.6
"""

The last type of tensor we’ll use is the variable. These are the values that will constitute our “model”. Our goal is to find values for these parameters that will make our model fit the data well. We’ll supply a data type, as always. In this situation, we’ll also provide an initial constant value. Normally, we’d want to use a random distribution of some kind. The tensor won’t actually take on its value until we run a global variable initializer function. We’ll have to create this initializer and then have our session object run it before we get going.

w = tf.Variable([3], dtype=tf.float32)
b = tf.Variable([1], dtype=tf.float32)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

Now let’s use our variables to create a “model” of sorts. For this article we'll make a simple linear model. Let’s create additional nodes for our input tensor and the model itself. We’ll let w be the weights, and b be the “bias”. This means we’ll construct our final value by w*x + b, where x is the input.

w = tf.Variable([3], dtype=tf.float32)
b = tf.Variable([1], dtype=tf.float32)
x = tf.placeholder(dtype=tf.float32)
linear_model = w * x + b

Now, we want to know how good our model is. So let’s compare it to y, an input of our expected values. We’ll take the difference, square it, and then use the reduce_sum library function to get our “loss”. The loss measures the difference between what we want our model to represent and what it actually represents.

w = tf.Variable([3], dtype=tf.float32)
b = tf.Variable([1], dtype=tf.float32)
x = tf.placeholder(dtype=tf.float32)
linear_model = w * x + b
y = tf.placeholder(dtype=tf.float32)
squared_deltas = tf.square(linear_model - y)
loss = tf.reduce_sum(squared_deltas)
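It helps to work the loss out by hand for our starting values. With w = 3, b = 1 and the training data used below (x = [1, 2, 3, 4], y = [4, 9, 14, 19]), the model predicts [4, 7, 10, 13], the deltas are [0, -2, -4, -6], and the loss is 0 + 4 + 16 + 36 = 56. The line we actually want (w = 5, b = -1) gives a loss of 0. Here is the same calculation as a small Haskell sketch (`linearModel` and `squaredLoss` are hypothetical names):

```haskell
-- The linear model w * x + b, applied to each input.
linearModel :: Double -> Double -> [Double] -> [Double]
linearModel w b = map (\x -> w * x + b)

-- Sum of squared differences between predictions and targets,
-- the same quantity as reduce_sum(square(linear_model - y)).
squaredLoss :: Double -> Double -> [Double] -> [Double] -> Double
squaredLoss w b xs ys =
  sum [ (p - y) ^ (2 :: Int) | (p, y) <- zip (linearModel w b xs) ys ]
```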

Each line here is a different tensor, or a new node in our graph. We’ll finish up our model by using the built-in GradientDescentOptimizer with a learning rate of 0.01. Our training step will try to minimize the loss function.

optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)

Now we’ll run the session, initialize the variables, and run our training step 1000 times. We’ll pass a series of inputs with their expected outputs. Let's try to learn the line y = 5x - 1. Our expected output y values will assume this.

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
for i in range(1000):
    sess.run(train, {x: [1, 2, 3, 4], y: [4,9,14,19]})

print(sess.run([w, b]))

At the end we print the weights and bias, and we see our results!

[array([ 4.99999475], dtype=float32), array([-0.99998516], dtype=float32)]

So we can see that our learned values are very close to the correct values of 5 and -1!
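GradientDescentOptimizer works out the derivatives for us. To see what each training step actually does, here is a hand-written Haskell version for this particular loss, with the derivatives worked out by hand (`step` and `trainLinear` are hypothetical names). Starting from (3, 1) with a learning rate of 0.01, 1000 steps should land very close to w = 5 and b = -1, in line with the Python result.

```haskell
-- One gradient-descent step for loss = sum ((w*x + b - y)^2):
--   dLoss/dw = sum (2 * (w*x + b - y) * x)
--   dLoss/db = sum (2 * (w*x + b - y))
step :: Double -> [Double] -> [Double] -> (Double, Double) -> (Double, Double)
step rate xs ys (w, b) = (w - rate * gradW, b - rate * gradB)
  where
    errs  = zipWith (\x y -> w * x + b - y) xs ys
    gradW = sum (zipWith (\e x -> 2 * e * x) errs xs)
    gradB = sum (map (2 *) errs)

-- Apply the step n times, starting from the article's initial values.
trainLinear :: Int -> (Double, Double)
trainLinear n = iterate (step 0.01 [1, 2, 3, 4] [4, 9, 14, 19]) (3, 1) !! n
```

The very first step moves (3, 1) to roughly (3.8, 1.24): every prediction is too low, so both gradients are negative and both parameters increase.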

Representing Tensors in Haskell

So now at long last, I’m going to get into some of the details of how we apply these tensor concepts in Haskell. Just as Haskell has several different string and number types, it doesn’t have one single “Tensor” type, since such a type would have to represent some very different concepts. For a deeper look at the tensor types we’re dealing with, check out our in-depth guide.

In the meantime, let’s go through some simple code snippets replicating our efforts in Python. Here’s how we make a few constants and add them together. Do note the “overloaded lists” extension. It allows us to represent different types with the same syntax as lists. We use this with both Shape items and Vectors:

{-# LANGUAGE OverloadedLists #-}

import Data.Vector (Vector)
import TensorFlow.Ops (constant, add)
import TensorFlow.Session (runSession, run)

runSimple :: IO (Vector Float)
runSimple = runSession $ do
  let node1 = constant [1] [3 :: Float]
  let node2 = constant [1] [4 :: Float]
  let additionNode = node1 `add` node2
  run additionNode

main :: IO ()
main = do
  result <- runSimple
  print result

{-
Output:
[7.0]
-}

We use the constant function, which takes a Shape and then the value we want. We’ll create our addition node and then run it to get the output, which is a vector with a single float. We wrap everything in the runSession function. This encapsulates the initialization and running actions we saw in Python.
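The OverloadedLists extension works through the IsList class from GHC.Exts: a list literal becomes a call to fromList, so any type with an IsList instance can be written with list syntax. Here is a minimal sketch with a hypothetical ToyShape newtype standing in for a shape type (it is not the library's own definition), so a literal like [2, 4] can be used wherever a ToyShape is expected.

```haskell
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TypeFamilies #-}

import Data.Int (Int64)
import GHC.Exts (IsList (..))

-- A stand-in for a shape: the size of each dimension.
newtype ToyShape = ToyShape [Int64]
  deriving (Show, Eq)

-- This instance lets a list literal like [2, 4] denote a ToyShape.
instance IsList ToyShape where
  type Item ToyShape = Int64
  fromList = ToyShape
  toList (ToyShape dims) = dims

-- Number of elements a tensor of this shape holds (1 for a scalar).
numElements :: ToyShape -> Int64
numElements (ToyShape dims) = product dims

matrixShape :: ToyShape
matrixShape = [2, 4]  -- desugars to fromList [2, 4]
```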

Now suppose we want placeholders. This is a little more complicated in Haskell. We’ll be using two placeholders, as we did in Python. We’ll initialize them with the placeholder function and a shape. Our function will take arguments for the input values. To actually pass the parameters to fill in the placeholders, we have to use what we call a “feed”.

We know that our adderNode depends on two values. So we’ll write our run-step as a function that takes in two “feed” values, one for each placeholder. Then we’ll assign those feeds to the proper nodes using the feed function. We’ll put these in a list, and pass that list as an argument to runWithFeeds. Then, we wrap up by calling our run-step on our input data. We’ll have to encode the raw vectors as tensors though.

{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE ScopedTypeVariables #-}

import TensorFlow.Core (Tensor, Value, feed, encodeTensorData)
import TensorFlow.Ops (constant, add, placeholder)
import TensorFlow.Session (runSession, run, runWithFeeds)

import Data.Vector (Vector)

runPlaceholder :: Vector Float -> Vector Float -> IO (Vector Float)
runPlaceholder input1 input2 = runSession $ do
  (node1 :: Tensor Value Float) <- placeholder [1]
  (node2 :: Tensor Value Float) <- placeholder [1]
  let adderNode = node1 `add` node2
  let runStep = \node1Feed node2Feed -> runWithFeeds 
        [ feed node1 node1Feed
        , feed node2 node2Feed
        ] 
        adderNode
  runStep (encodeTensorData [1] input1) (encodeTensorData [1] input2)

main :: IO ()
main = do
  result1 <- runPlaceholder [3.0] [4.5]
  result2 <- runPlaceholder [2.7] [8.9]
  print result1
  print result2

{-
Output:
[7.5]
[11.599999] -- Yay rounding issues!
-}

Now we’ll wrap up by going through the simple linear model scenario we already saw in Python. Once again, we’ll take two vectors as our inputs: the x values, and the y values we try to match. Next, we’ll use the initializedVariable function to get our variables. We don’t need to call a global variable initializer, but creating the variables does affect the state of the session. Notice that we pull them out of the monad context, rather than using let. (We did the same for placeholders.)

{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE ScopedTypeVariables #-}

import TensorFlow.Core (Tensor, Value, feed, encodeTensorData, Scalar(..))
import TensorFlow.Ops (constant, add, placeholder, sub, reduceSum, mul)
import TensorFlow.GenOps.Core (square)
import TensorFlow.Variable (readValue, initializedVariable, Variable)
import TensorFlow.Session (runSession, run, runWithFeeds)
import TensorFlow.Minimize (gradientDescent, minimizeWith)

import Control.Monad (replicateM_)
import qualified Data.Vector as Vector
import Data.Vector (Vector)

runVariable :: Vector Float -> Vector Float -> IO (Float, Float)
runVariable xInput yInput = runSession $ do
  let xSize = fromIntegral $ Vector.length xInput
  let ySize = fromIntegral $ Vector.length yInput
  (w :: Variable Float) <- initializedVariable 3
  (b :: Variable Float) <- initializedVariable 1
  …

Next, we’ll make our placeholders and linear model. Then we’ll calculate our loss function in much the same way we did before. Then we’ll use the same feed trick to get our placeholders plugged in.

runVariable :: Vector Float -> Vector Float -> IO (Float, Float)
  ...
  (x :: Tensor Value Float) <- placeholder [xSize]
  let linear_model = ((readValue w) `mul` x) `add` (readValue b)
  (y :: Tensor Value Float) <- placeholder [ySize]
  let square_deltas = square (linear_model `sub` y)
  let loss = reduceSum square_deltas
  trainStep <- minimizeWith (gradientDescent 0.01) loss [w,b] 
  let trainWithFeeds = \xF yF -> runWithFeeds
        [ feed x xF
        , feed y yF
        ]
        trainStep
…

Finally, we’ll run our training step 1000 times on our input data. Then we’ll run our model one more time to pull out the values of our weights and bias. Then we’re done!

runVariable :: Vector Float -> Vector Float -> IO (Float, Float)
...
  replicateM_ 1000 
    (trainWithFeeds (encodeTensorData [xSize] xInput) (encodeTensorData [ySize] yInput))
  (Scalar w_learned, Scalar b_learned) <- run (readValue w, readValue b)
  return (w_learned, b_learned)

main :: IO ()
main = do
  results <- runVariable [1.0, 2.0, 3.0, 4.0] [4.0, 9.0, 14.0, 19.0]
  print results

{-
Output:
(4.9999948,-0.99998516)
-}
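To get a feel for what minimizeWith (gradientDescent 0.01) does on each training step, here is a rough pure-Haskell sketch of the same idea. It is not the TensorFlow API. It assumes plain lists of Double instead of tensors, writes out the gradients of the summed squared error by hand, and uses hypothetical names (step, train). It starts from the same w = 3, b = 1 and takes 1000 steps, so its result should land close to the (5, -1) that the TensorFlow version learned.

```haskell
-- Gradient descent for the model y = w * x + b, minimizing
-- sum ((w * x + b - y)^2), the same loss as the TensorFlow example.
-- This is a hand-written sketch, not TensorFlow.

-- One gradient descent step with learning rate `rate`.
step :: Double -> [Double] -> [Double] -> (Double, Double) -> (Double, Double)
step rate xs ys (w, b) = (w - rate * gradW, b - rate * gradB)
  where
    -- Residuals (w * x + b - y) for every data point
    deltas = zipWith (\x y -> w * x + b - y) xs ys
    -- Derivative of the loss with respect to w: sum (2 * delta * x)
    gradW = sum (zipWith (\d x -> 2 * d * x) deltas xs)
    -- Derivative of the loss with respect to b: sum (2 * delta)
    gradB = sum (map (2 *) deltas)

-- Start from w = 3, b = 1 and take 1000 steps, mirroring runVariable.
train :: [Double] -> [Double] -> (Double, Double)
train xs ys = iterate (step 0.01 xs ys) (3, 1) !! 1000

main :: IO ()
main = print (train [1, 2, 3, 4] [4, 9, 14, 19])
```

Computing the gradient by hand only works for tiny models like this one; the point of TensorFlow's minimizeWith is that it derives these gradients automatically from the loss graph.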

Conclusion

Hopefully this article gave you a taste of some of the possibilities of Tensor Flow in Haskell. We saw a quick introduction to the fundamentals of Tensor Flow. We saw three different kinds of tensors. We then saw code examples both in Python and in Haskell. Finally, we went over a very quick example of a simple linear model and saw how we could learn values to fit that model. Next week, we’ll do a more complicated learning problem. We’ll use the classic “Iris” flower data set and train a classifier using a full neural network.

If you want more details, you should check out our FREE Haskell Tensor Flow Guide. It will walk you through using the Tensor Flow library as a dependency and getting a basic model running!

Perhaps you’re completely new to Haskell but intrigued by the possibilities of using it for machine learning or anything else. You should download our Getting Started Checklist! It has some great resources on installing Haskell and learning the core concepts.

James Bowen

The Future is Functional: Haskell and the AI-Native World

As regular readers of this blog know, I love talking about the future of Haskell as a language. I’m interested in ways we can shape the future of programming in a way that will help Haskell grow. I've mentioned network effects as a major hindrance a few times. Companies are reluctant to try Haskell since there aren't that many Haskell developers. As a result, fewer developers get the opportunity to be paid to learn Haskell. And the cycle continues.

Many perfectly smart people also have a bias against using Haskell in production code for a business. This stems from the idea that Haskell is an academic language. They see it as unsuited to “Real World” problems. The best rebuttal to this point is to show the many uses of Haskell in creating systems that people use every day. Now, I can sit here and point to the ease of creating web servers in Haskell. I could also point to the excellent mechanisms for designing front-end UIs. But there’s still one vital area in the future of programming that I have yet to address.

This is, of course, the world of AI and machine learning. AI is slowly (or not so slowly) becoming a primary concern for pretty much any software-based business. The last 5-10 years have seen the rise of “cloud native” architectures and systems. But we will soon be living in an age when all major software systems will use AI and machine learning at their core. In short, we are about to enter the AI Native Future, as my company’s founder put it.

This will be the first in a series of articles where I explore the uses of Haskell in writing AI applications. In the coming weeks I’ll be focusing on using the Tensor Flow bindings for Haskell. Tensor Flow allows programmers to build simple but powerful applications. There are many tutorials in Python, but the Haskell library is still in early stages. So I'll go through the core concepts of this library and show their usage in Haskell.

But for me it’s not enough to show that we can use Haskell for AI applications. That’s hardly going to move the needle or change the status quo. My ultimate goal is to prove that it’s the best tool for these kinds of systems. But first, let’s get an idea of where AI is being used, and why it’s so important.

AI Will be Everywhere...

...and this doesn’t seem to be particularly controversial. Year-on-year there are more and more discoveries in AI research. We are now able to solve problems that were scarcely thinkable a few years ago. Advancements in NLP systems like IBM Watson have made it so that chatbots are popping up all over the place. Tensor Flow has put advanced deep learning techniques at the fingertips of every programmer. Systems are getting faster and faster.

On top of that, the implications for the general public are becoming better known. Self-driving cars are roaming the streets in several major American cities. The idea of autonomous Amazon drones making deliveries seems a near certainty. Millions of people entrust part of their home systems to joint IoT/AI devices like Nest. AI is truly going to be everywhere soon.

AI Needs to be Safe

But the ubiquity of AI presents a large number of concerns as well. Software engineers are going to have a lot more responsibility. Our code will have to make difficult and potentially life-altering decisions. For instance, the design of self-driving cars carries many philosophical implications.

The plot thickens even more when we mix AI with the Internet of Things, another exploding market. In the last year, an attack brought down large parts of the internet using a botnet of IoT devices. In the world of IoT, security is still not treated as a top priority. But soon, more and more people will have cameras, audio recording devices, fire alarms and security systems hooked up to the internet. When this happens, their safety and privacy will depend on IoT security.

The need for safety and security suggests we may need to re-think some software paradigms. "Move fast and break things" is the prevailing mindset in many quarters. But this idea doesn't look so good if "breaking things" means someone's house burns down.

Pure, Functional Programming is the Path to Safety

So how does this relate to Haskell? Well, let’s consider the tradeoffs we face when we choose what language to develop in. Haskell, with its strong type system and compile-time guarantees, is more reliable. We can catch a lot more errors at compile time than we can in languages like Javascript or Python. In these languages, non-syntactic issues tend to only pop up at runtime. Programmers must lean even more heavily on testing systems to catch possible errors. But testing is difficult, and there’s still plenty of disagreement about the best methodologies.

The flip side of this is that it’s somewhat easier to write code in a language like Javascript. It's easier to cut corners in the type system and have more “dynamic” objects. So while we all want “reliable” software, we’re often willing to compromise to get code off the ground faster. This is the epitome of the "Move fast and break things" mindset.

However, the explosion in the safety concerns of our software has elevated the stakes. If someone’s web browser crashes from Javascript, it's no big deal. The user will reload the page and hopefully not trigger that condition again. If your app stops responding, your user might get frustrated and you’ll lose a customer. But when programming starts penetrating other markets, any error could be catastrophic. If a self-driving car encounters a sudden bug and the system crashes, many people could die. So it is our ethical responsibility to figure out ways to make it impossible for our software to encounter these kinds of errors.

Haskell is in many respects a very safe language. This is why it’s trusted by large financial institutions, large data science firms, and even by a company working in autonomous flight control. When your code cannot have arbitrary side effects, it is far easier to prevent it from crashing. It is also easier to secure a system (like an IoT device) when you can prevent leaks from arbitrary effects. Often these techniques are available in Haskell but not in other languages.

The field of dependent types is yet another area where we’ll be able to add more security to our programming. They'll enable even more compile-time guarantees of behavior. This can add a great deal of safety when used well. Haskell doesn’t have full support for dependent types yet, but it is in the works. In the meantime there are languages like Idris with first-class support.

Of course, when it comes to AI and deep learning, getting these guarantees will be difficult. It's one thing to build a type that ensures you have a vector of a particular length. It's quite another to build a type ensuring that when your car sees a dozen people standing ahead of it, it must brake. But these are the sorts of challenges programmers will need to face in the AI Native Future. And if we want to ensure Haskell’s place in that future, we’ll have to show these results are possible.
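To make the "vector of a particular length" idea concrete, here is a small sketch of how today's GHC can already approximate it with the DataKinds, GADTs and KindSignatures extensions. This is not full dependent typing, and the names (Nat, Vec, vhead, vzip, vecToList) are our own illustration rather than any particular library's API.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}

-- Type-level natural numbers: Z is zero, S n is n + 1
data Nat = Z | S Nat

-- A vector whose length n is tracked in its type
data Vec (n :: Nat) a where
  VNil  :: Vec 'Z a
  VCons :: a -> Vec n a -> Vec ('S n) a

-- Only accepts vectors of length at least one, so calling it on VNil
-- is a compile-time error rather than a runtime crash
vhead :: Vec ('S n) a -> a
vhead (VCons x _) = x

-- Both inputs must have the same length n, checked by the type checker
vzip :: Vec n a -> Vec n b -> Vec n (a, b)
vzip VNil VNil = VNil
vzip (VCons x xs) (VCons y ys) = VCons (x, y) (vzip xs ys)

-- Forget the length information to get an ordinary list
vecToList :: Vec n a -> [a]
vecToList VNil = []
vecToList (VCons x xs) = x : vecToList xs

main :: IO ()
main = do
  let v = VCons (1 :: Int) (VCons 2 VNil)
  print (vhead v)
  print (vecToList (vzip v v))
  -- print (vhead VNil)  -- rejected at compile time
```

A guarantee like "the car must brake when people are ahead" is far harder to express this way, which is exactly the gap the paragraph above describes.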

Conclusion

It's obvious that AI and machine learning are the big fields of software engineering. They’ll continue to dominate the field for a long time. They can have an incredible impact on our lives. But by allowing this impact, we’re putting more of our safety in the hands of software engineers. This has major implications for how we develop software.

We often have to make tradeoffs between ease of development and reliability. A language like Haskell can offer us a lot of compile time guarantees about the behavior of our program. These guarantees are absent from many other languages. However, achieving these guarantees can introduce more pain into the development process.

But soon our code will be controlling things like self-driving cars, delivery drones, and home security devices. So we have an ethical responsibility to do everything in our power to make our code as reliable as possible. For this reason, Haskell is in a prime position when it comes to the AI Native Future. Taking advantage of this will require a lot of work. Haskell programmers will have to develop language tools like dependent types to make Haskell even more reliable. We'll also have to contribute to libraries that will make it easy to write machine learning applications in Haskell.

With this in mind, the next few articles on this blog are all going to focus on using Haskell for machine learning. We’ll be starting by going through the basics of the Tensor Flow bindings for Haskell. You can get a sneak peek at some of that content by downloading our Haskell Tensor Flow tutorial!

If you’ve never programmed in Haskell before, you should try it out! We have two great resources for getting started. First, there’s the Getting Started Checklist. It will first walk you through downloading the language. Then it will point you in the direction of some other beginner materials. Second, there’s our Stack mini-course. This will walk you through the Stack tool, which makes it very easy to build projects, organize code, and get dependencies with Haskell.

James Bowen

Coping with (Code) Failures

Exception handling is annoying. It would be completely unnecessary if everything in the world worked the way it's supposed to. But of course that would be a fantasy. Haskell can’t change reality. But its error-handling facilities are a lot better than those of most languages. This week we'll look at some common error handling patterns. We’ll see a couple of clear instances where Haskell has simpler and cleaner code.

Using Nothing

The most basic example we have in Haskell is the Maybe type. This allows us to encapsulate any computation at all with the possibility of failure. Why is this better than similar ideas in other languages? Well, let’s take Java for example. It’s easy enough to imitate Maybe when you’re dealing with pointer types. You can use the “null” pointer as your failure case.

public MyObject squareRoot(int x) {
  if (x < 0) {
    return null; // null signals failure
  } else {
    return new MyObject(Math.sqrt(x));
  }
}

But this has a few disadvantages. First of all, null pointers (in general) look the same as regular pointers to the type checker. This means you get no compile time guarantees that ANY of the pointers you’re dealing with aren’t null. Imagine if we had to wrap EVERY Haskell value in a Maybe. We would need to CONSTANTLY unwrap them or else risk tripping a “null-pointer-exception”. In Haskell, once we’ve handled the Nothing case once, we can pass on a pure value. This allows other code to know that it will not throw random errors. Consider this example. We already check that our pointer is non-null in function1. Despite this, good programming practice dictates that we perform another check in function2.

public void function1(MyObject obj) {
  if (obj == null) {
    // Deal with error
  } else {
    function2(obj);
  }
}

public void function2(MyObject obj) {
  if (obj == null) {
    // ^^ We should be able to make this call redundant
  } else {
    // …
  }
}
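For comparison, here is a minimal Haskell sketch of the same two functions. MyObject is a hypothetical stand-in for the Java class, and the String results just mark which branch ran. Once function1 has matched on the Maybe, function2 receives a plain MyObject, so the second check simply has nothing to test.

```haskell
-- Hypothetical stand-in for the Java MyObject class
newtype MyObject = MyObject Double

-- function1 may receive a missing value, so it handles Nothing exactly once
function1 :: Maybe MyObject -> String
function1 Nothing = "Deal with error"
function1 (Just obj) = function2 obj

-- function2 takes a plain MyObject: the types guarantee there is no
-- "null" case left, so no redundant check is needed here
function2 :: MyObject -> String
function2 (MyObject x) = "Got " ++ show x

main :: IO ()
main = mapM_ (putStrLn . function1) [Nothing, Just (MyObject 2.0)]
```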

The second sticky point comes up when we’re dealing with non-pointer, primitive values. We often don't have a good way to handle these cases. Suppose your function returns an int, but it might fail. How do you represent failure? It’s not uncommon to see side cases like this handled by using a “sentinel” value, like 0 or -1.

But if the range of your function spans all the integers, you’re a little bit stuck there. The code might look cleaner if you use an enumerated type, but this doesn’t avoid the problem. The same problem can even crop up with pointer values if null is valid in the particular context.

public int integerSquareRoot(int x) {
  if (x < 0) {
    return -1; // Sentinel value meaning "failure"
  } else {
    return (int) Math.round(Math.sqrt(x));
  }
}

public void useSquareRoot(int a) {
  int result = integerSquareRoot(a);
  if (result == -1) {
    // Deal with error
  } else {
    // Use correct value
  }
}
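In Haskell, the same function can return Maybe Int instead of reserving -1 as a sentinel. Here is a small sketch that mirrors the Java version's rounding. The caller name describeRoot is our own, used only for illustration. Because failure is a separate constructor, every Int stays a legitimate result, and the type forces callers to handle both cases.

```haskell
-- Failure is Nothing, so no Int value has to be reserved as a sentinel
integerSquareRoot :: Int -> Maybe Int
integerSquareRoot x
  | x < 0 = Nothing
  | otherwise = Just (round (sqrt (fromIntegral x :: Double)))

-- Hypothetical caller: the case expression must cover both outcomes
describeRoot :: Int -> String
describeRoot a = case integerSquareRoot a of
  Nothing -> "Deal with error"
  Just result -> "Use correct value: " ++ show result

main :: IO ()
main = mapM_ (putStrLn . describeRoot) [16, -4]
```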

Finally, monadic composition with Maybe is much more natural in Haskell. There are many examples of this kind of spaghetti in Java code:

public Result computation1(MyObject value) {
  …
}

public Result computation2(Result res) {
  …
}

public int intFromResult(Result res) {
  …
}

public int spaghetti(MyObject value) {
  if (value != null) {
    Result result1 = computation1(value);
    if (result1 != null) {
      Result result2 = computation2(result1);
      if (result2 != null) {
        return intFromResult(result2);
      }
    }
  }
  return -1;
}

Now if we’re being naive, we might end up with a not-so-pretty version ourselves:

computation1 :: MyObject -> Maybe Result
computation2 :: Result -> Maybe Result
intFromResult :: Result -> Int

spaghetti :: Maybe MyObject -> Maybe Int
spaghetti value = case value of
  Nothing -> Nothing
  Just realValue -> case computation1 realValue of
    Nothing -> Nothing
    Just result1 -> case computation2 result1 of
      Nothing -> Nothing
      Just result2 -> return $ intFromResult result2

But as we discussed in our first Monads article, we can make this much cleaner. We'll compose our actions within the Maybe monad:

cleanerVersion :: Maybe MyObject -> Maybe Int
cleanerVersion value = do
  realValue <- value
  result1 <- computation1 realValue
  result2 <- computation2 result1
  return $ intFromResult result2
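The do block is sugar for chaining with the >>= operator. Here is a runnable sketch of the same pipeline written with >>= directly. The concrete bodies of computation1, computation2 and intFromResult, and the MyObject and Result types, are hypothetical, made up only so the example compiles. As soon as any step yields Nothing, the rest of the chain is skipped.

```haskell
-- Hypothetical types and functions so the example runs
newtype MyObject = MyObject Int
newtype Result = Result Int

computation1 :: MyObject -> Maybe Result
computation1 (MyObject n)
  | n > 0 = Just (Result (n * 2))
  | otherwise = Nothing

computation2 :: Result -> Maybe Result
computation2 (Result n)
  | n < 100 = Just (Result (n + 1))
  | otherwise = Nothing

intFromResult :: Result -> Int
intFromResult (Result n) = n

-- Same pipeline as cleanerVersion, with explicit >>= instead of do.
-- >>= is left-associative, so each step feeds the next.
bindVersion :: Maybe MyObject -> Maybe Int
bindVersion value =
  value >>= computation1 >>= computation2 >>= (return . intFromResult)

main :: IO ()
main = print (map bindVersion [Just (MyObject 5), Just (MyObject 0), Nothing])
```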

Using Either

Now suppose we want to make our errors contain a bit more information. In the example above, our function outputs Nothing if any step fails. But code calling that function will have no way of knowing what the error actually was. This might hinder our code's ability to correct the error. We'll also have no way of reporting a specific failure to the user. As we’ve explored, Haskell’s answer to this is the Either monad. This allows us to attach a value of any type as a possible failure. In this case, we'll change the type of each function. We would then update the functions to return a descriptive error message instead of Nothing.

computation1 :: MyObject -> Either String Result
computation2 :: Result -> Either String Result
intFromResult :: Result -> Int

eitherVersion :: Either String MyObject -> Either String Int
eitherVersion value = do
  realValue <- value
  result1 <- computation1 realValue
  result2 <- computation2 result1
  return $ intFromResult result2
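To see what the extra information buys us, here is a runnable sketch with hypothetical bodies for computation1 and computation2 that report why they failed. The error strings are invented for illustration. The do block is unchanged from eitherVersion. The first Left short-circuits the rest, and that message is what the caller gets back.

```haskell
-- Hypothetical types and functions so the example runs
newtype MyObject = MyObject Int
newtype Result = Result Int

computation1 :: MyObject -> Either String Result
computation1 (MyObject n)
  | n > 0 = Right (Result (n * 2))
  | otherwise = Left ("computation1: expected a positive input, got " ++ show n)

computation2 :: Result -> Either String Result
computation2 (Result n)
  | n < 100 = Right (Result (n + 1))
  | otherwise = Left ("computation2: result " ++ show n ++ " is too large")

intFromResult :: Result -> Int
intFromResult (Result n) = n

eitherVersion :: Either String MyObject -> Either String Int
eitherVersion value = do
  realValue <- value
  result1 <- computation1 realValue
  result2 <- computation2 result1
  return $ intFromResult result2

main :: IO ()
main = mapM_ (print . eitherVersion)
  [Right (MyObject 5), Right (MyObject 0), Right (MyObject 60), Left "no input"]
```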

Now suppose we want to try to make this happen in Java. How do we do this? There are a few options I’m aware of. None of them are particularly appetizing.

  1. Print an error message when the failure condition pops up.
  2. Update a global variable when the failure condition pops up.
  3. Create a new data type that could contain either an error value or a success value.
  4. Add a parameter to the function whose value is filled in with an error message if failure occurs.

The first two rely on arbitrary side effects. As Haskell programmers we aren’t fans of those. The third option would require messing with Java’s generic types. These are far more difficult to work with than Haskell’s parameterized types. If we don't take this approach, we'd need a new type for every different return value.

The last method is a bit of an anti-pattern that makes up for the fact that tuples aren’t a first-class construct in Java. It’s quite counter-intuitive to inspect one of your input parameters to find an output result. So given these options, give me Haskell any day.

Using Exceptions and Handlers

Now that we understand the more “pure” ways of handling error cases in our code, we can deal with exceptions. Exceptions show up in almost every major programming language; Haskell is no different. Haskell has the SomeException type that encapsulates possible failure conditions. It can wrap any type that is a member of the Exception typeclass. You'll generally be creating your own exception types.

Generally, we throw exceptions when we want to state that a path of code execution has failed. Instead of returning some value to the calling function, we'll allow completely different code to handle the error. If this sounds convoluted, that’s because it kind of is. In general you want to keep the control flow as clear as possible. Sometimes though we cannot avoid it.

So let’s suppose we’re calling a function we know might throw a particular exception. We can “handle” that exception by attaching a handler. In Java, this pattern looks like so:

public int integerSquareRoot(int value) throws NegativeSquareRootException {
  ...
}

public int mathFunction(int x) {
  try {
    return 2 * integerSquareRoot(x);
  } catch (NegativeSquareRootException e) {
    // Deal with invalid result
    return -1;
  }
}

To handle exceptions in this manner in Haskell, you have to be in the IO monad. The most general way to handle exceptions is to use the catch function. When you call the action that might throw the exception, you include a “handler” function. This function takes the exception as an argument and deals with the case. If we want to write the above example in Haskell, we should first define our exception type. We only need to derive Show; then we can declare an empty instance of the Exception typeclass, since all of its methods have default implementations:

import Control.Exception (Exception)

data MyException = NegativeSquareRootException
  deriving (Show)

instance Exception MyException

Now we can write a pure function that will throw this exception in the proper circumstances.

import Control.Exception (Exception, throw)

integerSquareRoot :: Int -> Int
integerSquareRoot x
  | x < 0 = throw NegativeSquareRootException
  | otherwise = floor (sqrt (fromIntegral x :: Double))

While we can throw the exception from pure code, we need to be in the IO monad to catch it. We’ll do this with the catch function. We’ll use a handler function that will only catch the specific error we’re expecting. It will print the error as a message and then return a dummy value.

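import Control.Exception (Exception, throw, catch, evaluate)

…
mathFunction :: Int -> IO Int
mathFunction input =
  -- evaluate forces the pure result so the exception is raised inside catch;
  -- with plain `return`, the unevaluated thunk would escape the handler
  catch (evaluate (integerSquareRoot input)) handler
  where
    handler :: MyException -> IO Int
    handler NegativeSquareRootException =
      putStrLn "Can't call square root on a negative number!" >> return (-1)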
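Control.Exception also provides try, which turns an exception of a chosen type into a Left value instead of calling a handler. Here is a small self-contained sketch of that alternative. The name safeSquareRoot is our own. As with catch, evaluate is needed: integerSquareRoot is pure, so without forcing it inside try, the exception would only surface later, outside of try.

```haskell
import Control.Exception (Exception, evaluate, throw, try)

data MyException = NegativeSquareRootException
  deriving (Show)

instance Exception MyException

integerSquareRoot :: Int -> Int
integerSquareRoot x
  | x < 0 = throw NegativeSquareRootException
  | otherwise = floor (sqrt (fromIntegral x :: Double))

-- try returns Left for a MyException and Right for a normal result
safeSquareRoot :: Int -> IO (Either MyException Int)
safeSquareRoot input = try (evaluate (integerSquareRoot input))

main :: IO ()
main = do
  safeSquareRoot 16 >>= print   -- Right 4
  safeSquareRoot (-5) >>= print -- Left NegativeSquareRootException
```

Whether to prefer catch or try is mostly a matter of taste: try keeps the error in an ordinary Either value that the rest of the code can pattern match on.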

MonadThrow

We can also generalize this process a bit to work in different monads. The MonadThrow typeclass (from the Control.Monad.Catch module of the exceptions package) allows us to specify different exceptional behaviors for different monads. For instance, Maybe throws exceptions by using Nothing. Either uses Left, and IO will use throwIO. When we’re in a general MonadThrow function, we throw exceptions with throwM.

import Control.Monad.Catch (MonadThrow(..), SomeException)

callWithMaybe :: Maybe Int
callWithMaybe = integerSquareRoot (-5) -- Gives us `Nothing`

callWithEither :: Either SomeException Int
callWithEither = integerSquareRoot (-5) -- Gives us `Left NegativeSquareRootException`

callWithIO :: IO Int
callWithIO = integerSquareRoot (-5) -- Throws the exception in IO, as with throwIO

integerSquareRoot :: (MonadThrow m) => Int -> m Int
integerSquareRoot x
  | x < 0 = throwM NegativeSquareRootException
  | otherwise = return (floor (sqrt (fromIntegral x :: Double)))

There is some debate about whether the extra layers of abstraction are that helpful. There is a strong case to be made that if you’re going to be using exceptional control flow, you should be using IO anyway. But using MonadThrow can make your code more extensible. Your function might be usable in more areas of your codebase. I’m not too opinionated on this topic (not yet at least). But there are certainly some strong opinions within the Haskell community.

Summary

Error handling is tricky business. A lot of the common programming patterns around error handling are annoying to write. Luckily, Haskell has several different means of doing this. In Haskell, you can express errors using simple mechanisms like Maybe and Either. Their monadic behavior gives you a high degree of composability. You can also throw and catch exceptions like you can in other languages. But Haskell has some more general ways to do this. This allows you to be agnostic to how functions within your code handle errors.

New to Haskell? Amazed by its awesomeness and want to try? Download our Getting Started Checklist! It has some awesome tools and instructions for getting Haskell on your computer and starting out.

Have you tried Haskell but want some more practice? Check out our Recursion Workbook for some great content and 10 practice problems!

And stay tuned to the Monday Morning Haskell blog!

James Bowen

Getting the User's Opinion: Options in Haskell

GUIs are hard to create. Luckily for us, we can often get away with making our code available through a command line interface. As you start writing more Haskell programs, you'll probably have to do this at some point.

This article will go over some of the ins and outs of CLIs. In particular, we’ll look at the basics of handling options. Then we'll see some nifty techniques for actually testing the behavior of our CLI.

A Simple Sample

To motivate the examples in this article, let’s design a simple program. We’ll have the user input a message. Then we’ll print the message to a file a certain number of times and list the user’s name as the author at the top. We’ll also allow them to uppercase the message if they want. So we’ll get five pieces of input from the user:

  1. The filename they want
  2. Their name to place at the top
  3. Whether they want to uppercase or not
  4. The message
  5. The repetition number

We’ll use arguments and options for the first three pieces of information. Then we'll have a command line prompt for the other two. For instance, we’ll insist the user pass the expected file name as an argument. Then we’ll take an option for the name the user wants to put at the top. Finally, we’ll take a flag for whether the user wants the message upper-cased. So here are a few different invocations of the program.

>> run-cli "myfile.txt" -n "John Doe"
What message do you want in the file?
Sample Message
How many times should it be repeated?
5

This will print the following output to myfile.txt:

From: John Doe
Sample Message
Sample Message
Sample Message
Sample Message
Sample Message

Here’s another run, this time with an error in the input:

>> run-cli "myfile2.txt" -n "Jane Doe" -u
What message do you want in the file?
A new message
How many times should it be repeated?
asdf
Sorry, that isn't a valid number. Please enter a number.
3

This file will look like:

From: Jane Doe
A NEW MESSAGE
A NEW MESSAGE
A NEW MESSAGE

Finally, if we don’t get the right arguments, we should get a usage error:

>> run-cli
Missing: FILENAME -n USERNAME

Usage: CLIPractice-exe FILENAME -n USERNAME [-u]
  Command Line Sample Program

Getting the Input

So the most important aspect of the program is getting the message and repetitions. We’ll ignore the options for now. We’ll print a couple of messages, and then use the getLine function to get the user's input. There’s no way for them to give us a bad message, so this section is easy.

getMessage :: IO String
getMessage = do
  putStrLn "What message do you want in the file?"
  getLine

But they might try to give us a number we can’t actually parse. So for this task, we’ll have to set up a loop where we keep asking the user for a number until they give us a good value. This will be recursive in the failure case. If the user won’t enter a valid number, they’ll have no choice but to terminate the program by other means.

getRepetitions :: IO Int
getRepetitions = do
  putStrLn "How many times should it be repeated?"
  getNumber

getNumber :: IO Int
getNumber = do
  rep <- getLine
  case readMaybe rep of
    Nothing -> do
      putStrLn "Sorry, that isn't a valid number. Please enter a number."
      getNumber
    Just i -> return i

Once we’re done reading the input, we’ll print the output to a file. In this instance, we hard-code all the options for now. Here’s the full program.

import Data.Char (toUpper)
import System.IO (writeFile)
import Text.Read (readMaybe)

runCLI :: IO ()
runCLI = do
  let fileName = "myfile.txt"
  let userName = "John Doe"
  let isUppercase = False
  message <- getMessage
  reps <- getRepetitions
  writeFile fileName (fileContents userName message reps isUppercase)

fileContents :: String -> String -> Int -> Bool -> String
fileContents userName message repetitions isUppercase = unlines $
  ("From: " ++ userName) :
  (replicate repetitions finalMessage)
  where
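    -- `then` and `else` must be indented past `finalMessage`,
    -- otherwise the layout rule ends the binding early
    finalMessage = if isUppercase
      then map toUpper message
      else message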
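Because fileContents is pure, we can already check the formatting logic without touching the terminal or the file system. Here is a small sketch that repeats the function (so the block stands on its own) and prints the contents for the second sample run from earlier.

```haskell
import Data.Char (toUpper)

-- Same pure function as above: a header line plus the repeated message
fileContents :: String -> String -> Int -> Bool -> String
fileContents userName message repetitions isUppercase = unlines $
  ("From: " ++ userName) :
  replicate repetitions finalMessage
  where
    finalMessage = if isUppercase
      then map toUpper message
      else message

main :: IO ()
main = putStr (fileContents "Jane Doe" "A new message" 3 True)
```

Keeping this logic out of IO is what makes the comparison trivial; the harder part, testing the interactive prompts, is what handles help with later in the article.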

Parsing Options

Now we have to deal with the question of how we actually parse the different options. We can do this by hand with the getArgs function, but this is somewhat error-prone. A better option in general is to use the optparse-applicative library, which provides the Options.Applicative module. We’ll explore the different possibilities this library allows. We’ll use three different helper functions for the three pieces of input we need.

The first thing we’ll do is build a data structure to hold the different options we want. We want to know the file name to store at, the name at the top, and the uppercase status.

data CommandOptions = CommandOptions
  { fileName :: FilePath
  , userName :: String
  , isUppercase :: Bool }

Now we need to parse each of these. We’ll start with the uppercase value. The simplest parser we have is the flag function. It checks whether a particular flag (we’ll call it -u) is present. If it is, we’ll uppercase the message; otherwise we won't. It gets coded like this with the Options library:

uppercaseParser :: Parser Bool
uppercaseParser = flag False True (short 'u')

Notice we use short in the final argument to denote the flag character. We could also use the switch function, since this flag is only a boolean, but this version is more general.
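optparse-applicative also exports execParserPure and getParseResult, which run a parser on a fixed list of arguments instead of the real command line. That gives us a quick way to confirm that switch behaves like our flag version. The helper name parseArgs is our own, not part of the library.

```haskell
import Options.Applicative

-- The flag-based parser from above: False when -u is absent, True when present
uppercaseParser :: Parser Bool
uppercaseParser = flag False True (short 'u')

-- The same behavior using switch, which is defined as flag False True
uppercaseSwitch :: Parser Bool
uppercaseSwitch = switch (short 'u')

-- Hypothetical helper: run a parser against a fixed argument list.
-- Nothing means the arguments were rejected.
parseArgs :: Parser a -> [String] -> Maybe a
parseArgs p args = getParseResult (execParserPure defaultPrefs (info p mempty) args)

main :: IO ()
main = do
  print (map (parseArgs uppercaseParser) [[], ["-u"]])
  print (map (parseArgs uppercaseSwitch) [[], ["-u"]])
```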

Now we’ll move on to the argument for the filename. This uses the argument helper function. We’ll use a string parser (str) to ensure we get the actual string. We won’t worry about the filename having a particular format here. Notice we add some metadata to this argument. This tells the user what they are missing if they don’t use the proper format.

fileNameParser :: Parser String
fileNameParser = argument str (metavar "FILENAME")

Finally, we’ll deal with the option of what name will go at the top. We could also do this as an argument, but let’s see what the option is like. An argument is a required positional parameter. An option, on the other hand, comes after a particular flag. We add metadata here too, for a better error message. The short piece of our metadata ensures it will use the option character we want.

userNameParser :: Parser String
userNameParser = option str (short 'n' <> metavar "USERNAME")

Now we have to combine these different parsers and add a little more info about our program.

import Options.Applicative (execParser, info, helper, Parser, fullDesc, 
  progDesc, short, metavar, flag, argument, str, option)

parseOptions :: IO CommandOptions
parseOptions = execParser $ info (helper <*> commandOptsParser) commandOptsInfo
  where
    commandOptsParser = CommandOptions <$> fileNameParser <*> userNameParser <*> uppercaseParser
    commandOptsInfo = fullDesc <> progDesc "Command Line Sample Program"

-- Revamped to take options
runCLI :: CommandOptions -> IO ()
runCLI commandOptions = do
  let file = fileName commandOptions
  let user = userName commandOptions
  let uppercase = isUppercase commandOptions
  message <- getMessage
  reps <- getRepetitions
  writeFile file (fileContents user message reps uppercase)

And now we’re done! We build our command object using these three different parsers. We chain the operations together using applicatives! Then we pass the result to our main program. If you aren’t too familiar with functors and applicatives, we went over these a while ago on the blog. Refresh your memory!
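Before moving on to IO testing, it's worth noting that the option parsing itself can be checked without running the program, again using execParserPure. The sketch below reassembles the same three parsers. It adds deriving (Show, Eq) to CommandOptions so results can be compared, and parseArgs is a hypothetical helper name. As a design note, optparse-applicative lets options and positional arguments appear in any order by default, which the last check relies on.

```haskell
import Options.Applicative

data CommandOptions = CommandOptions
  { fileName :: FilePath
  , userName :: String
  , isUppercase :: Bool }
  deriving (Show, Eq) -- added so parse results can be compared

-- The same three parsers as above, combined applicatively
commandOptsParser :: Parser CommandOptions
commandOptsParser = CommandOptions
  <$> argument str (metavar "FILENAME")
  <*> option str (short 'n' <> metavar "USERNAME")
  <*> flag False True (short 'u')

-- Hypothetical helper: parse a fixed argument list instead of the real one
parseArgs :: [String] -> Maybe CommandOptions
parseArgs = getParseResult . execParserPure defaultPrefs opts
  where
    opts = info (helper <*> commandOptsParser)
      (fullDesc <> progDesc "Command Line Sample Program")

main :: IO ()
main = do
  print (parseArgs ["myfile.txt", "-n", "John Doe"])
  print (parseArgs ["-u", "-n", "Jane Doe", "myfile2.txt"])
  print (parseArgs []) -- missing arguments, so Nothing
```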

IO Testing

Now that we have our program working, we need to ask ourselves how we can test its behavior. We can do manual command line tests ourselves, but it would be nice to have an automated solution. The key to this is the Handle abstraction.

Let’s first look at some basic file handling types.

openFile :: FilePath -> IOMode -> IO Handle
hGetLine :: Handle -> IO String
hPutStrLn :: Handle -> String -> IO ()
hClose :: Handle -> IO ()

Normally when we write something to a file, we open a handle for it. We use the handle (instead of the string literal name) for all the different operations. When we’re done, we close the handle.
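Here is a tiny sketch of that open, use, close cycle with the functions listed above. It writes a line through one handle and reads it back through another. The file name handle_demo.txt is just an example. In real code, withFile from System.IO is often preferred because it closes the handle even if an exception is thrown.

```haskell
import System.IO

-- Open a handle, use it for every operation, then close it.
writeAndReadBack :: FilePath -> IO String
writeAndReadBack path = do
  outHandle <- openFile path WriteMode
  hPutStrLn outHandle "Hello, handle!"
  hClose outHandle
  inHandle <- openFile path ReadMode
  line <- hGetLine inHandle
  hClose inHandle
  return line

main :: IO ()
main = writeAndReadBack "handle_demo.txt" >>= putStrLn
```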

The good news is that the stdin and stdout streams are actually the exact same Handle type under the hood!

stdin :: Handle
stdout :: Handle

How does this help us test? The first step is to abstract away the handles we’re working with. Instead of using putStrLn and getLine, we’ll want to use hPutStrLn and hGetLine. Then we take the handles as arguments to our program and functions. Let’s look at our reading functions:

getMessage :: Handle -> Handle -> IO String
getMessage inHandle outHandle = do
  hPutStrLn outHandle "What message do you want in the file?"
  hGetLine inHandle

getRepetitions :: Handle -> Handle -> IO Int
getRepetitions inHandle outHandle = do
  hPutStrLn outHandle "How many times should it be repeated?"
  getNumber inHandle outHandle

getNumber :: Handle -> Handle -> IO Int
getNumber inHandle outHandle = do
  rep <- hGetLine inHandle
  case readMaybe rep of
    Nothing -> do
      hPutStrLn outHandle "Sorry, that isn't a valid number. Please enter a number."
      getNumber inHandle outHandle
    Just i -> return i

Once we’ve done this, we can make the input and output handles parameters to our program as follows. Our wrapper executable will pass stdin and stdout:

-- Library File:
runCLI :: Handle -> Handle -> CommandOptions -> IO ()
runCLI inHandle outHandle commandOptions = do
  let file = fileName commandOptions
  let user = userName commandOptions
  let uppercase = isUppercase commandOptions
  message <- getMessage inHandle outHandle
  reps <- getRepetitions inHandle outHandle
  writeFile file (fileContents user message reps uppercase)

-- Executable File
import System.IO (stdin, stdout)
import Lib

main :: IO ()
main = do
  options <- parseOptions
  runCLI stdin stdout options

Now our library API takes the handles as parameters. This means in our testing code, we can pass whatever handle we want to test the code. And, as you may have guessed, we’ll do this with files, instead of stdin and stdout. We’ll make one file with our expected terminal output:

What message do you want in the file?
How many times should it be repeated?

We’ll make another file with our input:

Sample Message
5

And then the file we expect to be created:

From: John Doe
Sample Message
Sample Message
Sample Message
Sample Message
Sample Message

Now we can write a test calling our library function. It will pass the expected arguments object as well as the proper file handles. Then we can compare the output of our test file and the output file.

import Lib

import System.IO
import Test.HUnit

main :: IO ()
main = do
  inputHandle <- openFile "input.txt" ReadMode
  outputHandle <- openFile "terminal_output.txt" WriteMode
  runCLI inputHandle outputHandle options
  hClose inputHandle
  hClose outputHandle
  expectedTerminal <- readFile "expected_terminal.txt"
  actualTerminal <- readFile "terminal_output.txt"
  expectedFile <- readFile "expected_output.txt"
  actualFile <- readFile "testOutput.txt"
  assertEqual "Terminal Output Should Match" expectedTerminal actualTerminal
  assertEqual "Output File Should Match" expectedFile actualFile

options :: CommandOptions
options = CommandOptions "testOutput.txt" "John Doe" False

And that’s it! We can also use this process to add tests around the error cases, like when the user enters invalid numbers.
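Here is a sketch of such an error-case test. It copies the article's getNumber and drives it with a temporary input file that holds one invalid line followed by a valid one. The runWithFiles helper is hypothetical (it is not part of the article's code). It uses openTempFile from System.IO and removeFile from the directory package, which ships with GHC.

```haskell
import System.Directory (removeFile)
import System.IO
import Text.Read (readMaybe)

-- Copy of the article's getNumber: re-prompt until a line parses as an Int.
getNumber :: Handle -> Handle -> IO Int
getNumber inHandle outHandle = do
  rep <- hGetLine inHandle
  case readMaybe rep of
    Nothing -> do
      hPutStrLn outHandle "Sorry, that isn't a valid number. Please enter a number."
      getNumber inHandle outHandle
    Just i -> return i

-- Hypothetical helper: write the input lines to a temporary file, run the
-- action against file handles, and return its result plus everything it
-- wrote to the output handle. Both temporary files are deleted afterwards.
runWithFiles :: [String] -> (Handle -> Handle -> IO a) -> IO (a, String)
runWithFiles inputLines action = do
  (inPath, inWrite) <- openTempFile "." "input.txt"
  hPutStr inWrite (unlines inputLines)
  hClose inWrite
  (outPath, outHandle) <- openTempFile "." "output.txt"
  inHandle <- openFile inPath ReadMode
  result <- action inHandle outHandle
  hClose inHandle
  hClose outHandle
  output <- readFile outPath
  -- Force the whole (lazily read) contents before deleting the file.
  length output `seq` removeFile outPath
  removeFile inPath
  return (result, output)
```

Calling `runWithFiles ["abc", "7"] getNumber` should give back 7, and the captured output should be exactly one "Sorry, that isn't a valid number..." line. The same helper works for runCLI, since it also takes its input and output handles as its first two arguments.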

Summary

Writing a command line interface isn't always the easiest task. Getting a user’s input sometimes requires creating loops if they won’t give you the information you want. Then dealing with arguments can be a major pain. The Options.Applicative library contains many option parsing tools. It helps you deal with flags, options, and arguments. When you're ready to test your program, you'll want to abstract the file handles away. You can use stdin and stdout from your main executable. But then when you test, you can use files as your input and output handles.

Want to try writing a CLI but don't know Haskell yet? No sweat! Download our Getting Started Checklist and get going learning the language!

When you're making a full project with executables and test suites, you need to keep organized! Take our FREE Stack mini-course to learn how to organize your Haskell with Stack.

James Bowen

Cleaning Up Our Projects with Hpack!

About a month ago, we released our FREE Stack mini-course. If you've never used Stack before, you should totally check out that course! This article will give you a sneak peek at some of the content in the course.

But if you're already familiar with the basics of Stack and don't think you need the course, don't worry! In this article we'll be going through another cool tool to streamline your workflow!

Most any Haskell project you create will use Cabal under the hood. Cabal performs several important tasks for you. It downloads dependencies for you and locates the code you wrote within the file system. It also links all your code so GHC can compile it. In order for Cabal to do this, you need to create a .cabal file describing all these things. It needs to know for instance what libraries each section of your code depends on. You'll also have to specify where the source directories are on your file system.

The organization of a .cabal file is a little confusing at times. The syntax can be quite verbose. We can make our lives simpler though if we use the "Hpack" tool. Hpack allows you to specify your project organization in a more concise format. Once you’ve specified everything in Hpack’s format, you can generate the .cabal file with a single command.

Using Hpack

The first step to using Hpack is, of course, to install it. This is simple as long as you have Stack installed on your system. Use the command:

stack install hpack

This will install Hpack system wide so you can use it in all your projects. The next step is to specify your code’s organization in a file called package.yaml. Here’s a simple example:

name: HpackExampleProject

version: 0.1.0.0

ghc-options: -Wall

dependencies:
  - base

library:
  source-dirs: src/

executables:
  HpackExampleProject-exe:
    main: Main.hs
    source-dirs: app/
    dependencies:
      - HpackExampleProject

tests:
  HpackExampleProject-test:
    main: Spec.hs
    source-dirs: test/
    dependencies:
      - HpackExampleProject

This example will generate a very simple cabal file. It'll look a lot like the default template of running stack new HpackExampleProject. There are a few basic fields, like the name, version and compiler options for our project. We can specify details about our code library, executables, and any test suites we have, just as we can in a .cabal file. Each of these components can have their own dependencies. We can also specify global dependencies.

Once we have created this file, all we need to do is run the hpack command from the directory containing this file. This will generate our .cabal file:

>> hpack
generated HpackExampleProject.cabal

What problems does Hpack solve?

One big problem hpack solves is module inference. Your .cabal file should specify all Haskell modules which are part of your library. You'll always have two groups: “exposed” modules and “other” modules. It can be quite annoying to list every one of these modules, and the list can get quite long as your library gets bigger! Worse, you'll sometimes get confusing errors when you create a new module but forget to add it to the .cabal file. With Hpack, you don't need to remember to add most new files! It looks at your file system and determines what modules are present for you. Suppose you have organized your files like so in your source directory:

src/Lib.hs
src/API.hs
src/Internal/Helper.hs
src/Internal/Types.hs

Using the normal .cabal approach, you would need to list these modules by hand. But without listing anything in the package.yaml file, you’ll get all your modules listed in the .cabal file:

exposed-modules:
  API
  Internal.Helper
  Internal.Types
  Lib

Now, you might not want all your modules exposed. You can make a simple modification to the package file:

library:
  source-dirs: src/
  exposed-modules:
    - API
    - Lib

And hpack will correct the .cabal file.

exposed-modules:
  API
  Lib
other-modules:
  Internal.Helper
  Internal.Types

From this point, Hpack will infer all new Haskell modules as “other” modules. You'll only need to list "exposed" modules in package.yaml. There's only one thing to remember! You need to run hpack each time you add new modules, or else Cabal will not know where your code is. This is still much easier than modifying the .cabal file each time. The .cabal file itself will still contain a long list of module names. But most of them won’t be present in the package.yaml file, which is your main point of interaction.
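To make the inference idea concrete, here is a rough Haskell illustration of what Hpack does conceptually: map each .hs file under the source directory to a module name, then split the modules into exposed and other ones. The names pathToModule and inferModules are hypothetical, and this is a simplified sketch rather than Hpack's actual implementation (for instance, it only looks at plain .hs files).

```haskell
import Data.List (isSuffixOf, partition, sort, stripPrefix)
import Data.Maybe (mapMaybe)

-- Illustration only: turn "src/Internal/Helper.hs" into "Internal.Helper",
-- given the source directory "src/". Files that aren't .hs yield Nothing.
pathToModule :: FilePath -> FilePath -> Maybe String
pathToModule sourceDir path = do
  relative <- stripPrefix sourceDir path
  if ".hs" `isSuffixOf` relative
    then Just (map slashToDot (take (length relative - 3) relative))
    else Nothing
  where
    slashToDot c = if c == '/' then '.' else c

-- Split the inferred modules into (exposed, other). With no exposed-modules
-- field every module is exposed; otherwise only the listed ones are.
inferModules :: FilePath -> Maybe [String] -> [FilePath] -> ([String], [String])
inferModules sourceDir exposedField paths =
  case exposedField of
    Nothing -> (modules, [])
    Just exposed -> partition (`elem` exposed) modules
  where
    modules = sort (mapMaybe (pathToModule sourceDir) paths)
```

Fed the four files from the example above, `inferModules "src/" Nothing` exposes all four modules, while `inferModules "src/" (Just ["API", "Lib"])` reproduces the exposed/other split shown in the generated .cabal file.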

Another big benefit of using hpack is global dependencies. Notice in the example how we have a “dependencies” field above all the other sections. This means our library, executables, and test-suites will all get base as a dependency for free. Without hpack, we would need to specify base as a dependency in each individual section.

There is also plenty of other syntactic sugar available with hpack. One simple example is the github specification. You can put the following single line in your package file:

github: username/reponame

And you’ll get the following lines for free in your .cabal file.

homepage:          https://github.com/username/reponame#readme
bug-reports:       https://github.com/username/reponame/issues

Summary

Once you move beyond toy projects, maintenance of your package will be non-trivial. If you use hpack, you’ll have an easier time organizing your first big project. The syntax is cleaner. The organization is more intuitive. Finally, you will save yourself the stress of performing many repetitive tasks. Constant edits to the .cabal file will interrupt your flow and build process. So avoiding them should make you a lot more productive.

Now if you haven't used Stack or Cabal at all before, there was a lot to grasp here. But hopefully you're convinced that there are easy ways to organize your Haskell code! If you're intrigued at learning how, sign up for our FREE Stack mini-course! You'll learn all about the simple approaches to organizing a Haskell project.

If you've never used Haskell at all and are totally confused by all this, no need to fret! Download our Getting Started Checklist and you'll be well on your way to learning Haskell!

James Bowen

Playing Match Maker

In last week’s article we saw an introduction to the Functional Graph Library. This is a neat little library that allows us to build graphs in Haskell. It then makes it very easy to solve basic graph problems. For instance, we could plug our graph into a function that would give us the shortest path. In another example, we found the minimum spanning tree of our graph with a single function call.

Those examples were rather contrived though. Our “input” was already in a graph format more or less, so we didn’t have to think much to convert it. Then we solved arbitrary problems without providing any real context. In programming, graph algorithms often come up when you’re least expecting them! We’ll prove this with a sample problem.

Motivating Example

Suppose we’re building a house. We have several people who are working on the house, and they all have various tasks to do. They need certain tools to do these tasks. As long as a person gets a tool for one of the jobs they’re working on, they can make progress. Of course, we have a limited supply of tools. So suppose we have this set of tools:

Hammer
Hammer
Power Saw
Ladder
Ladder
Ladder
Caulking Gun

And here are the people working on this house, each followed by the tools they can use:

Jason, Hammer, Ladder, Caulking Gun
Amanda, Hammer
Kristina, Caulking Gun
Chad, Ladder
Josephine, Power Saw
Chris, Power Saw, Ladder
Dennis, Caulking Gun, Hammer

We want to assign tools to people so that as many people as possible get at least one tool they can use (each tool goes to at most one person). In this situation we can actually find an assignment that gets all seven people a tool:

Jason - Ladder
Amanda - Hammer
Kristina - Caulking Gun
Chad - Ladder
Josephine - Power Saw
Chris - Ladder
Dennis - Hammer

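We’ll read our problem in from a handle like we did last time. We assume the input first gives the number of tools and then the number of people, followed by one tool per line and then one line per person with their name and the tools they can use. Our output will be the list of tools and then a list pairing each person’s name with the tools they can use.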

module Tools where

import           Control.Monad (replicateM)
import           Data.List.Split (splitOn)
import           System.IO (hGetLine, Handle)

readInput :: Handle -> IO ([String], [(String, [String])])
readInput handle = do
  numTools <- read <$> hGetLine handle
  numPeople <- read <$> hGetLine handle
  tools <- replicateM numTools (hGetLine handle)
  people <- replicateM numPeople (readPersonLine handle)
  return (tools, people)

readPersonLine :: Handle -> IO (String, [String]) 
readPersonLine handle = do
  line <- hGetLine handle
  let components = splitOn ", " line
  return (head components, tail components)
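Because readInput needs a Handle, a pure version can be handy in tests. The sketch below is a hypothetical parseInput that assumes the same input layout described above. It also swaps Data.List.Split.splitOn for a tiny base-only splitOnComma helper, so it runs without the split package.

```haskell
-- Hypothetical pure counterpart to readInput. Assumed layout: number of
-- tools, number of people, one tool per line, then "Name, Tool1, Tool2, ...".
parseInput :: String -> ([String], [(String, [String])])
parseInput contents = (tools, people)
  where
    (numToolsLine : numPeopleLine : rest) = lines contents
    numTools = read numToolsLine :: Int
    numPeople = read numPeopleLine :: Int
    (tools, personLines) = splitAt numTools rest
    people = map parsePerson (take numPeople personLines)

-- First field is the name; the remaining fields are the usable tools.
parsePerson :: String -> (String, [String])
parsePerson line = case splitOnComma line of
  (name : usable) -> (name, usable)
  [] -> (line, [])

-- Minimal stand-in for splitOn ", " from Data.List.Split.
splitOnComma :: String -> [String]
splitOnComma s = case breakOn s of
  (chunk, Nothing) -> [chunk]
  (chunk, Just rest) -> chunk : splitOnComma rest
  where
    breakOn (',' : ' ' : xs) = ("", Just xs)
    breakOn (c : xs) = let (chunk, rest) = breakOn xs in (c : chunk, rest)
    breakOn [] = ("", Nothing)
```

On the sample problem this yields seven tools and seven people, with Jason's entry parsed as `("Jason", ["Hammer", "Ladder", "Caulking Gun"])`.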

Some Naive Solutions

Now our first guess might be to try a greedy algorithm. We’ll iterate through the list of tools, find the first person in the list who can use that tool, and recurse on the rest. This might look a little like this:

solveToolsGreedy :: Handle -> IO Int
solveToolsGreedy handle = do
  (tools, personMap) <- readInput handle
  -- readInput already returns an association list, so no Map conversion is needed
  return $ findMaxMatchingGreedy tools personMap

findMaxMatchingGreedy :: [String] -> [(String, [String])] -> Int
findMaxMatchingGreedy [] _ = 0 -- No more tools to match
findMaxMatchingGreedy (tool : rest) personMap = case break (containsTool tool) personMap of
  (_, []) -> findMaxMatchingGreedy rest personMap -- Can't match this tool
  -- Give the tool to the first person who can use it, then drop that person
  (somePeople, (_ : otherPeople)) -> 1 + findMaxMatchingGreedy rest (somePeople ++ otherPeople)

containsTool :: String -> (String, [String]) -> Bool
containsTool tool pair = tool `elem` (snd pair)

Unfortunately, this can lead to sub-optimal outcomes. For example, if the caulking gun came first in our list of tools, our greed would assign it to Jason, and then Kristina wouldn’t be able to use any tools.
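We can check the greedy strategy against the sample data. The sketch below is a variant of findMaxMatchingGreedy (the name greedyAssignment is ours) that returns (person, tool) pairs instead of a count. With the tools in the order listed above, it gives the first hammer to Jason and only assigns six people: Dennis, who needs a caulking gun or a hammer, ends up empty-handed. Moving the caulking gun to the front of the list hands it to Jason and leaves Kristina out instead.

```haskell
-- Sample data from the article, in the order listed above.
sampleTools :: [String]
sampleTools = ["Hammer", "Hammer", "Power Saw", "Ladder", "Ladder", "Ladder", "Caulking Gun"]

samplePeople :: [(String, [String])]
samplePeople =
  [ ("Jason", ["Hammer", "Ladder", "Caulking Gun"])
  , ("Amanda", ["Hammer"])
  , ("Kristina", ["Caulking Gun"])
  , ("Chad", ["Ladder"])
  , ("Josephine", ["Power Saw"])
  , ("Chris", ["Power Saw", "Ladder"])
  , ("Dennis", ["Caulking Gun", "Hammer"])
  ]

containsTool :: String -> (String, [String]) -> Bool
containsTool tool pair = tool `elem` snd pair

-- Same greedy strategy, but recording who received which tool.
greedyAssignment :: [String] -> [(String, [String])] -> [(String, String)]
greedyAssignment [] _ = []
greedyAssignment (tool : rest) people = case break (containsTool tool) people of
  (_, []) -> greedyAssignment rest people -- nobody left can use this tool
  (before, (name, _) : after) -> (name, tool) : greedyAssignment rest (before ++ after)
```

Either ordering matches only six of the seven people, even though a full assignment exists.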

So now let’s try to fix this by making two recursive calls! We’ll find the first person we can assign the tool to (or drop the tool if nobody can use it). Once we’ve done this, we’ll imagine two scenarios. In case 1, this person uses the tool, so we remove both the tool and the person from our lists, recurse on the remainder, and add 1. In case 2, this person does NOT use the tool, so we recurse after REMOVING the tool from that person’s list. Then we take the better of the two results.

import Data.List (delete) -- with the module's other imports

findMaxMatchingSlow :: [String] -> [(String, [String])] -> Int
findMaxMatchingSlow [] _ = 0
findMaxMatchingSlow allTools@(tool : rest) personMap = 
  case break (containsTool tool) personMap of
    (_, []) -> findMaxMatchingSlow rest personMap -- Can't match this tool
    (somePeople, (chosen : otherPeople)) -> max useIt loseIt
      where
        -- Case 1: the chosen person uses this tool
        useIt = 1 + findMaxMatchingSlow rest (somePeople ++ otherPeople)
        -- Case 2: the chosen person doesn't get this tool, but others still might
        loseIt = findMaxMatchingSlow allTools newList
        newList = somePeople ++ (modifiedChosen : otherPeople)
        modifiedChosen = dropTool tool chosen

dropTool :: String -> (String, [String]) -> (String, [String])
dropTool tool (name, validTools) = (name, delete tool validTools)

The good news is that this will get us the optimal solution! It solves our simple case quite well! The bad news is that it will take too long on more difficult cases. A naive use-it-or-lose-it algorithm like this will take exponential time (O(2^n)). This means even for modest input sizes (~100) we’ll be waiting for a loooong time. Anything much larger takes prohibitively long. Plus, there’s no easy way for us to memoize the solution here.
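A self-contained version of the use-it-or-lose-it search, with the sample data inlined, confirms that it finds all seven assignments. When no one can use a tool, it recurses into findMaxMatchingSlow itself (not the greedy function), so every branch stays exhaustive.

```haskell
import Data.List (delete)

sampleTools :: [String]
sampleTools = ["Hammer", "Hammer", "Power Saw", "Ladder", "Ladder", "Ladder", "Caulking Gun"]

samplePeople :: [(String, [String])]
samplePeople =
  [ ("Jason", ["Hammer", "Ladder", "Caulking Gun"])
  , ("Amanda", ["Hammer"])
  , ("Kristina", ["Caulking Gun"])
  , ("Chad", ["Ladder"])
  , ("Josephine", ["Power Saw"])
  , ("Chris", ["Power Saw", "Ladder"])
  , ("Dennis", ["Caulking Gun", "Hammer"])
  ]

containsTool :: String -> (String, [String]) -> Bool
containsTool tool pair = tool `elem` snd pair

dropTool :: String -> (String, [String]) -> (String, [String])
dropTool tool (name, validTools) = (name, delete tool validTools)

-- Exhaustive use-it-or-lose-it search (exponential time).
findMaxMatchingSlow :: [String] -> [(String, [String])] -> Int
findMaxMatchingSlow [] _ = 0
findMaxMatchingSlow allTools@(tool : rest) people =
  case break (containsTool tool) people of
    (_, []) -> findMaxMatchingSlow rest people -- nobody can use this tool
    (before, chosen : after) -> max useIt loseIt
      where
        -- The chosen person takes the tool.
        useIt = 1 + findMaxMatchingSlow rest (before ++ after)
        -- The chosen person gives up on this tool; the tool stays available.
        loseIt = findMaxMatchingSlow allTools (before ++ (dropTool tool chosen : after))
```

The recursion terminates because the "lose it" branch always shrinks the chosen person's tool list.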

Graphs to the Rescue!

So at this point, are we condemned to choose between a fast inaccurate algorithm and a correct but slow one? In this case the answer is no! This problem is actually best solved by using a graph algorithm! This is an example of what’s called a “bipartite matching” problem. We’ll create a graph with two sets of nodes. On the left, we’ll have a node for each tool. On the right, we’ll have a node for each person. The only edges in our graph will go from nodes on the left towards nodes on the right. A “tool” node has an edge to a “person” node if that person can use the tool. Here’s a partial representation of our graph (plus or minus my design skills). We’ve only drawn in the edges related to Amanda, Christine and Josephine so far.

Now we want to find the “maximum matching” in this graph. That is, we want the largest set of edges such that no two edges share a node. The way to solve this problem using graph algorithms is to turn it into yet ANOTHER graph problem! We’ll add a node on the far left, called the “source” node. We’ll connect it to every “tool” node. Now we’ll add a node on the far right, called the “sink” node. It will receive an edge from every “person” node. Every edge in this graph has a capacity of 1.

Again, most of the middle edges are missing here.

The size of the maximum matching in this case is equal to the “max flow” from the source node to the sink node. This is a somewhat advanced concept. But imagine there is water gushing out of the source node and that every edge is a “pipe” whose value (1) is the capacity. We want the largest amount of water that can go through to the sink at once.
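To make the “pipes” picture concrete, here is a sketch of what max flow boils down to when every pipe has capacity 1. Each unit of water travels along an augmenting path from the source to the sink, and in this network such a path alternates between unused and used tool-to-person edges. Repeatedly searching for one augmenting path per tool is known as Kuhn's algorithm. This is a swapped-in, plain-Haskell technique using the containers package; it is not how FGL implements maxFlow, and the names augment and maxMatching are ours.

```haskell
import Data.List (foldl')
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set

-- Matching from right-node index (person) to left-node index (tool).
type Matching = Map.Map Int Int

-- Look for an augmenting path starting at a left node. `visited` holds the
-- right nodes already explored during the current search.
augment :: (Int -> [Int]) -> Int -> Set.Set Int -> Matching -> (Bool, Set.Set Int, Matching)
augment neighbors left visited0 matching = go (neighbors left) visited0
  where
    go [] visited = (False, visited, matching)
    go (right : others) visited
      | right `Set.member` visited = go others visited
      | otherwise =
          let visited' = Set.insert right visited
          in case Map.lookup right matching of
               -- The right node is free: match it directly.
               Nothing -> (True, visited', Map.insert right left matching)
               -- The right node is taken: try to move its partner elsewhere.
               Just partner -> case augment neighbors partner visited' matching of
                 (True, visited'', matching') -> (True, visited'', Map.insert right left matching')
                 (False, visited'', _) -> go others visited''

-- Size of a maximum matching, i.e. the max flow of the unit-capacity network.
maxMatching :: (a -> b -> Bool) -> [a] -> [b] -> Int
maxMatching predicate leftItems rightItems = Map.size (foldl' step Map.empty leftIndices)
  where
    leftIndices = [0 .. length leftItems - 1]
    indexedRights = zip [0 ..] rightItems
    adjacency = Map.fromList
      [ (i, [j | (j, b) <- indexedRights, predicate a b]) | (i, a) <- zip leftIndices leftItems ]
    neighbors i = Map.findWithDefault [] i adjacency
    step matching left =
      let (_, _, matching') = augment neighbors left Set.empty matching
      in matching'

containsTool :: String -> (String, [String]) -> Bool
containsTool tool pair = tool `elem` snd pair

sampleTools :: [String]
sampleTools = ["Hammer", "Hammer", "Power Saw", "Ladder", "Ladder", "Ladder", "Caulking Gun"]

samplePeople :: [(String, [String])]
samplePeople =
  [ ("Jason", ["Hammer", "Ladder", "Caulking Gun"])
  , ("Amanda", ["Hammer"])
  , ("Kristina", ["Caulking Gun"])
  , ("Chad", ["Ladder"])
  , ("Josephine", ["Power Saw"])
  , ("Chris", ["Power Saw", "Ladder"])
  , ("Dennis", ["Caulking Gun", "Hammer"])
  ]
```

`maxMatching containsTool sampleTools samplePeople` evaluates to 7, agreeing with the exhaustive search, but each tool triggers only one path search instead of an exponential number of branches.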

Last week we saw built-in functions for shortest path and min spanning tree. FGL also has an out-of-the-box solution for max flow. So our main goal now is to take our objects and construct the above graph.

Preparing Our Solution

A couple weeks ago, we created a segment tree that was very specific to the problem. This time, we’ll show what it’s like to write a more generic algorithm. Throughout the rest of the article, you can imagine that a is a tool, and b is a person. We’ll write a general maxMatching function that will take a list of a’s, a list of b’s, AND a predicate function. This function will determine whether an a object and a b object should have an edge between them. We’ll use the containsTool function from above as our predicate. Then we'll call our general function.

findMaxMatchingBest :: [String] -> [(String, [String])] -> Int
findMaxMatchingBest tools personMap = findMaxMatching containsTool tools personMap

…(different module)

findMaxMatching :: (a -> b -> Bool) -> [a] -> [b] -> Int
findMaxMatching predicate as bs = ...

Building Our Graph

To build our graph, we’ll have to decide on our labels. Once again, we’ll only label our edges with integers. In fact, they’ll all have a “capacity” label of 1. But our nodes will be a little more complicated. We’ll want to associate the node with the object, and we have a heterogeneous (and polymorphic) set of items. We’ll make this NodeLabel type that could refer to any of the four types of nodes:

data NodeLabel a b = 
  LeftNode a |
  RightNode b |
  SourceNode |
  SinkNode

Next we’ll start building our graph by constructing the inner part. We’ll make the two sets of nodes as well as the edges connecting them. We’ll assign the left nodes to the indices from 1 up through the size of that list. And then the right nodes will take on the indices from one above the first list's size through the sum of the list sizes.

createInnerGraph 
  :: (a -> b -> Bool) 
  -> [a]
  -> [b]
  -> ([LNode (NodeLabel a b)], [LNode (NodeLabel a b)], [LEdge Int])
createInnerGraph predicate as bs = ...
  where
    sz_a = length as
    sz_b = length bs
    aNodes = zip [1..sz_a] (LeftNode <$> as)
    bNodes = zip [(sz_a + 1)..(sz_a + sz_b)] (RightNode <$> bs)

Next we’ll make tuples matching each index to the item itself, without its node label wrapper, so that we can call the predicate on the items. We’ll then get all the edges out by using a list comprehension: we’ll pull out each pairing and determine whether the predicate holds. If it does, we’ll add the edge.

where
  ...
  indexedAs = zip [1..sz_a] as
  indexedBs = zip [(sz_a + 1)..(sz_a + sz_b)] bs
  nodesAreConnected (_, aItem) (_, bItem) = predicate aItem bItem
  edges = [(fst aN, fst bN, 1) | aN <- indexedAs, bN <- indexedBs, nodesAreConnected aN bN]

Now we’ve got all our pieces, so we combine them to complete the definition:

createInnerGraph predicate as bs = (aNodes, bNodes, edges)

Now we’ll construct the “total graph”. This will include the source and sink nodes. It will include the indices of these nodes in the return value so that we can use them in our algorithm:

totalGraph :: (a -> b -> Bool) -> [a] -> [b] 
  -> (Gr (NodeLabel a b) Int, Int, Int)

Now we’ll start our definition by getting all the pieces out of the inner graph as well as the size of each list. Then we’ll assign the index for the source and sink to be the numbers after these combined sizes. We’ll also make the nodes themselves and give them the proper labels.

totalGraph predicate as bs = ...
  where
    sz_a = length as
    sz_b = length bs
    (leftNodes, rightNodes, middleEdges) = createInnerGraph predicate as bs
    sourceIndex = sz_a + sz_b + 1
    sinkIndex = sz_a + sz_b + 2
    sourceNode = (sourceIndex, SourceNode)
    sinkNode = (sinkIndex, SinkNode)

Now to finish this definition, we’ll first create edges from the source out to the left nodes. Then we'll make edges from the right nodes to the sink. We’ll use list comprehensions there as well. Then we’ll combine all our nodes and edges into two lists.

where
  ...
  sourceEdges = [(sourceIndex, lIndex, 1) | lIndex <- fst <$> leftNodes]
  sinkEdges = [(rIndex, sinkIndex, 1) | rIndex <- fst <$> rightNodes]
  allNodes = sourceNode : sinkNode : (leftNodes ++ rightNodes)
  allEdges = sourceEdges ++ middleEdges ++ sinkEdges

Finally, we’ll complete the definition by making our graph. As we noted, we'll also return the source and sink indices:

totalGraph predicate as bs = (mkGraph allNodes allEdges, sourceIndex, sinkIndex)
  where
    ...
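To see the index layout without pulling in FGL, here is a sketch (networkEdges is a hypothetical name) that builds the same labeled-edge triples as plain tuples. Left items take the indices 1 through n, right items take n+1 through n+m, and the source and sink take the next two indices.

```haskell
-- Plain-tuple version of the edge construction in createInnerGraph/totalGraph.
-- Returns (sourceIndex, sinkIndex, edges), each edge as (from, to, capacity).
networkEdges :: (a -> b -> Bool) -> [a] -> [b] -> (Int, Int, [(Int, Int, Int)])
networkEdges predicate leftItems rightItems = (sourceIndex, sinkIndex, allEdges)
  where
    szA = length leftItems
    szB = length rightItems
    indexedAs = zip [1 .. szA] leftItems
    indexedBs = zip [szA + 1 .. szA + szB] rightItems
    sourceIndex = szA + szB + 1
    sinkIndex = szA + szB + 2
    -- Tool-to-person edges wherever the predicate holds.
    middleEdges = [(i, j, 1) | (i, a) <- indexedAs, (j, b) <- indexedBs, predicate a b]
    -- Source feeds every left node; every right node drains into the sink.
    sourceEdges = [(sourceIndex, i, 1) | (i, _) <- indexedAs]
    sinkEdges = [(j, sinkIndex, 1) | (j, _) <- indexedBs]
    allEdges = sourceEdges ++ middleEdges ++ sinkEdges
```

For the sample tools and people, this gives 7 source edges, 20 tool-to-person edges, and 7 sink edges, with the source at index 15 and the sink at index 16.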

The Grand Finale

OK one last step! We can now fill in our findMaxMatching function. We’ll first get the necessary components from building the graph. Then we’ll call the maxFlow function. This works out of the box, just like sp and msTree from the last article!

import Data.Graph.Inductive.Graph (LNode, LEdge, mkGraph)
import Data.Graph.Inductive.PatriciaTree (Gr)
import Data.Graph.Inductive.Query.MaxFlow (maxFlow)

findMaxMatching :: (a -> b -> Bool) -> [a] -> [b] -> Int
findMaxMatching predicate as bs = maxFlow graph source sink
  where
    (graph, source, sink) = totalGraph predicate as bs

And we’re done! This will always give us the correct answer, and it runs in polynomial time instead of exponential time! Take a look at the code on Github if you want to experiment with it!

Conclusion

Whew, algorithms are exhausting, aren’t they? That was a ton of code we just wrote. Let’s do a quick review. This time around we looked at an actual problem that was not obviously a graph problem. We tried a couple of different algorithmic approaches, but both had issues. Ultimately, we found that a graph algorithm was the solution, and we were able to implement it with FGL.

If you want to use FGL (or most any awesome Haskell library), it would help a ton if you learned how to use Stack! This great tool wraps project organization and package management into one tool. Check out our FREE Stack mini-course and learn more!

If you’ve never programmed in Haskell at all, then what are you waiting for? It’s super fun! You should download our Getting Started Checklist for some tips and resources on starting out!

Stay tuned next week for more on the Monday Morning Haskell Blog!

James Bowen

Graphing it Out

In the last two articles we dealt with some algorithmic issues. We were only able to solve these by improving the data structures in our code. First we used Haskell's built-in array type for fast indexing. Then, when we needed a segment tree, we decided to make it from scratch. But we can’t roll our own data structure for every problem we encounter. So it’s good to have some systems we can use for more of these advanced topics.

One of the most important categories of data structures we’ll need for algorithms is graphs. Graphs are quite powerful when it comes to representing complicated problems. They are very useful for expressing relationships between data points. In this article, we'll see two types of graph problems. We’ll learn all about a library called the Functional Graph Library (FGL) that is available on Hackage for us to use. We’ll then take a stab at constructing graphs using the library. Finally, we’ll see how simple it is to solve these problems using this library once we’ve made our graphs.

For a complete set of the code that we’ll use in this article, check out this Github repository. It’ll show you how you can use Stack to bring the Functional Graph Library into your code.

If you’ve never used Stack before, it’s an indispensable tool for creating programs in Haskell. You should try out our Stack mini-course and learn more about it.

Graphs 101

For those of you who aren’t familiar with graphs and graph algorithms, I’ll explain a few of the basics here. If you’re very familiar with these already, you can skip this section. A graph contains a series of objects and encodes various relationships between these objects. For each object in our set, we have a “node” in our graph. These are like data points. Then to represent every relationship, we create an “edge” between two different nodes. We often give some kind of value to this edge as a piece of information about the relationship. In this article, we’ll imagine our nodes as places on a map, with the edges representing legal routes between locations. The label of each edge is the distance.

Edges can be either “directed” or “undirected”. Directed edges describe a 1-way relationship between the nodes. Undirected edges describe a 2-way relationship. Here's an example of a graph with directed edges:

[Figure: an example graph with directed, weighted edges]

Making Graphs with FGL

So whenever we’re making a graph with FGL, we’ll use a two-step process. First, we’ll create the “nodes” of our graph. Then we’ll encode the edges. Let’s suppose we’re reading an input stream. The stream will first give us the number of nodes and then the number of edges in our graph. Then we’ll read a bunch of 3-tuples line-by-line. The first two numbers will refer to the "from" node and the "to" node. The third number will represent the distance. Here’s what this stream might look like for the graph pictured above:

6
9
1 2 3
1 3 4
2 3 5
2 4 2
2 5 6
3 5 5
4 6 9
5 4 1
5 6 10

We’ll read this input stream like so:

import Control.Monad (replicateM)
import System.IO (Handle, hGetLine)

data EdgeSpec = EdgeSpec
  { fromNode :: Int
  , toNode :: Int
  , distance :: Int
  }

readInputs :: Handle -> IO (Int, [EdgeSpec])
readInputs handle = do
  numNodes <- read <$> hGetLine handle
  numEdges <- (read <$> hGetLine handle)
  edges <- replicateM numEdges (readEdge handle)
  return (numNodes, edges)

readEdge :: Handle -> IO EdgeSpec
readEdge handle = do
  input <- hGetLine handle
  let [f_s, t_s, d_s] = words input
  return $ EdgeSpec (read f_s) (read t_s) (read d_s)
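The pattern `let [f_s, t_s, d_s] = words input` in readEdge fails at runtime if a line doesn't have exactly three fields, and `read` fails on non-numbers. If you'd rather reject bad input gracefully, here is a sketch of a total alternative built on readMaybe from Text.Read. The names parseEdge and parseInputs are hypothetical, and EdgeSpec is redefined with `deriving (Show, Eq)` so results can be compared.

```haskell
import Text.Read (readMaybe)

data EdgeSpec = EdgeSpec
  { fromNode :: Int
  , toNode :: Int
  , distance :: Int
  } deriving (Show, Eq)

-- Parse one "from to distance" line; Nothing for anything malformed.
parseEdge :: String -> Maybe EdgeSpec
parseEdge line = case words line of
  [f, t, d] -> EdgeSpec <$> readMaybe f <*> readMaybe t <*> readMaybe d
  _ -> Nothing

-- Parse the whole stream: node count, edge count, then that many edge lines.
parseInputs :: String -> Maybe (Int, [EdgeSpec])
parseInputs contents = case lines contents of
  (nodesLine : edgesLine : rest) -> do
    numNodes <- readMaybe nodesLine
    numEdges <- readMaybe edgesLine
    edges <- mapM parseEdge (take numEdges rest)
    -- Reject streams that promise more edges than they contain.
    if length edges == numEdges then Just (numNodes, edges) else Nothing
  _ -> Nothing
```

Using mapM in the Maybe monad means a single bad edge line makes the whole parse return Nothing.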

Our goal will be to encode this graph in the format of FGL. In this library, every node has an integer identifier. Nodes can either be “labeled” or “unlabeled”. This label, if it exists, is separate from the integer identifier. The function we’ll use requires our nodes to have labels, but we won’t need this extra information. So we’ll use a newtype to wrap the same integer identifier.

Once we know the number of nodes, it’s actually quite easy to create them all. We'll make labels from every number from 1 up through the length. Then we represent each node by the tuple of its index and label. Let’s start a function for creating our graph:

import Data.Graph.Inductive.Graph (mkGraph)
import Data.Graph.Inductive.PatriciaTree (Gr)

…

newtype NodeLabel = NodeLabel Int 
type Distance = Int

genGraph :: (Int, [EdgeSpec]) -> Gr NodeLabel Distance
genGraph (numNodes, edgeSpecs) = mkGraph nodes edges
  where
    nodes = (\i -> (i, NodeLabel i)) 
      <$> [1..numNodes]
    edges = ...

The graph we're making uses a "Patricia Tree" encoding under the hood. We won't go into details about that. We'll just call a simple mkGraph function exposed by the library. We'll make our return value the graph type Gr parameterized by our node label type and our edge label type. As we can see, we’ll use a type synonym Distance for integers to label our edges.

For now let’s get to the business of creating our edges. Thanks to the format we chose for EdgeSpec, we don’t have to do much work. Just as a labeled node is a synonym for a tuple, a labeled edge is a 3-tuple. It contains the index of the “from” node, the index of the “to” node, and then the distance label. In this case we’ll use directed edges. We do this for every edge spec, and then we’re done!

genGraph :: (Int, [EdgeSpec]) -> Gr NodeLabel Distance
genGraph (numNodes, edgeSpecs) = mkGraph nodes edges
  where
    nodes = (\i -> (i, NodeLabel i)) 
      <$> [1..numNodes]
    edges = (\es -> (fromNode es, toNode es, distance es))
      <$> edgeSpecs

Using Graph Algorithms

Now suppose we want to solve a particular graph problem. First we’ll tackle shortest path. If we remember from above, the shortest path from node 1 to node 6 on our graph actually just goes along the top, from 1 to 2 to 4 to 6.

[Figure: the shortest path from node 1 to node 6]

How can we solve this from Haskell? First, we’ll use our functions from above to read in the graph. Then we’ll read in two more numbers for the start and end nodes.

solveSP :: Handle -> IO ()
solveSP handle = do
  inputs <- readInputs handle
  start <- read <$> hGetLine handle
  end <- read <$> hGetLine handle
  let gr = genGraph inputs

Now with FGL we can simply make a couple library calls and we’ll get our results! We’ll use the Query.SP module, which exposes functions to find the shortest path and its length:

import Data.Graph.Inductive.Query.SP (sp, spLength)

solveSP :: Handle -> IO ()
solveSP handle = do
  inputs <- readInputs handle
  start <- read <$> hGetLine handle
  end <- read <$> hGetLine handle
  let gr = genGraph inputs
  print $ sp start end gr
  print $ spLength start end gr

We’ll get our output, which contains a representation of the path as well as the distance. Imagine “input.txt” contains our sample input above, except with two more lines for the start and end nodes “1” and “6”:

>> find-shortest-path < input.txt
[1,2,4,6]
14

We could change our file to instead go from 3 to 6, and then we’d get:

>> find-shortest-path < input2.txt
[3,5,4,6]
15

Cool!
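If you want to double-check these answers without FGL, here is an independent sketch of Dijkstra's algorithm using Data.Map and Data.Set from the containers package. The set of (distance, node) pairs acts as the priority queue. This is not FGL's implementation, and shortestDistance and buildGraph are our own names.

```haskell
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set

-- Adjacency map: node -> [(neighbor, distance)].
type Graph = Map.Map Int [(Int, Int)]

-- The sample graph's directed edges as (from, to, distance).
sampleEdges :: [(Int, Int, Int)]
sampleEdges =
  [ (1, 2, 3), (1, 3, 4), (2, 3, 5), (2, 4, 2), (2, 5, 6)
  , (3, 5, 5), (4, 6, 9), (5, 4, 1), (5, 6, 10) ]

buildGraph :: [(Int, Int, Int)] -> Graph
buildGraph edges = Map.fromListWith (++) [(from, [(to, d)]) | (from, to, d) <- edges]

-- Dijkstra's algorithm; Nothing when `end` is unreachable from `start`.
shortestDistance :: Graph -> Int -> Int -> Maybe Int
shortestDistance graph start end = go (Set.singleton (0, start)) (Map.singleton start 0)
  where
    go frontier best = case Set.minView frontier of
      Nothing -> Nothing
      Just ((d, node), rest)
        | node == end -> Just d
        | d > Map.findWithDefault maxBound node best -> go rest best -- stale entry
        | otherwise ->
            let improvements =
                  [ (nd, next)
                  | (next, w) <- Map.findWithDefault [] node graph
                  , let nd = d + w
                  , nd < Map.findWithDefault maxBound next best
                  ]
                frontier' = foldr Set.insert rest improvements
                best' = foldr (\(nd, next) -> Map.insertWith min next nd) best improvements
            in go frontier' best'
```

With `buildGraph sampleEdges`, this returns `Just 14` from node 1 to node 6 and `Just 15` from node 3 to node 6, matching the FGL output. (From 3 to 6 there is actually a tie: 3 to 5 to 6 also costs 15.)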

Minimum Spanning Tree

Now let’s imagine a different problem. Suppose our nodes are internet hubs. We only want to make sure they’re all connected to each other somehow. We’re going to pick a subset of the edges that will create a “spanning tree”, connecting all our nodes. Of course, we want to do this in the cheapest way, by using the smallest total “distance” from the edges. This will be our “Minimum Spanning Tree”. First, let’s remove the directions on all the arrows. We can visualize this solution by looking at this picture, and we’ll see that we can connect our nodes at a total cost of 19.

The great news is that it’s not much more work to code this! First, we’ll adjust our graph construction a bit. To have an “undirected” graph in this scenario, we can make our arrows bi-directional like so:

genUndirectedGraph :: (Int, [EdgeSpec]) -> Gr NodeLabel Distance
genUndirectedGraph (numNodes, edgeSpecs) = mkGraph nodes edges
  where
    nodes = (\i -> (i, NodeLabel i)) 
      <$> [1..numNodes]
    edges = concatMap (\es -> 
      [(fromNode es, toNode es, distance es), (toNode es, fromNode es, distance es)])
      edgeSpecs

Besides this, all we have to do now is use the msTree function from the MST module! Then we'll get our result!

import Data.Graph.Inductive.Query.MST (msTree)

...

solveMST :: Handle -> IO ()
solveMST handle = do
  inputs <- readInputs handle
  let gr = genUndirectedGraph inputs
  print $ msTree gr

{- GHC Output
>> find-mst < input1.txt
[[(1,0)],[(2,3),(1,0)],[(4,2),(2,3),(1,0)],[(5,1),(4,2),(2,3),(1,0)],[(3,4),(1,0)],[(6,9),(4,2),(2,3),(1,0)]]
-}

This output is a little difficult to interpret, but it’s identical to the tree we saw above. Our output is a list of lists. Each sub-list contains a path from a node to our first node. A path list has a series of tuples, with each tuple corresponding to an edge. The first element of the tuple is the starting node, and the second is the distance.

So the first element, [(1,0)], refers to how node 1 connects to itself, by a single “edge” of distance 0 starting at 1. Then if we look at the last entry, we see node 6 connects to node 1 via a path through nodes 4 and 2, using edges of distance 9, 2, and 3 (14 in total).
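As a cross-check on the total cost of 19, here is a sketch of Kruskal's algorithm with a very simple union-find stored in a Data.Map. This is a swapped-in technique for verification, not what msTree does internally, and findRoot and kruskal are our own names. Since the tree is undirected, we only need one direction of each edge, not the doubled edges from genUndirectedGraph.

```haskell
import Data.List (sortOn)
import qualified Data.Map.Strict as Map

-- The sample graph's edges as (from, to, distance); direction is ignored here.
sampleEdges :: [(Int, Int, Int)]
sampleEdges =
  [ (1, 2, 3), (1, 3, 4), (2, 3, 5), (2, 4, 2), (2, 5, 6)
  , (3, 5, 5), (4, 6, 9), (5, 4, 1), (5, 6, 10) ]

-- Follow parent links to a component's representative. Nodes missing from
-- the map are their own roots. (No path compression, for simplicity.)
findRoot :: Map.Map Int Int -> Int -> Int
findRoot parents node = case Map.lookup node parents of
  Just parent | parent /= node -> findRoot parents parent
  _ -> node

-- Kruskal: scan edges from cheapest to most expensive, keeping each edge
-- that joins two different components.
kruskal :: [(Int, Int, Int)] -> [(Int, Int, Int)]
kruskal edges = go (sortOn (\(_, _, d) -> d) edges) Map.empty
  where
    go [] _ = []
    go (e@(u, v, _) : rest) parents
      | ru == rv = go rest parents
      | otherwise = e : go rest (Map.insert ru rv parents)
      where
        ru = findRoot parents u
        rv = findRoot parents v
```

On the sample graph it keeps the edges of distance 1, 2, 3, 4, and 9: five edges connecting all six nodes, for a total of 19.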

Conclusion

Graph problems are ubiquitous in programming, but it can be a little tricky to get them right. It can be a definite pain to write your own graph data structure from scratch. It can be even harder to write out a full algorithm, even a well known one like Dijkstra's Algorithm. In Haskell, you can use the Functional Graph Library to streamline this process. It has a built-in format for representing the graph itself. It can be a little tedious to build up this structure. But once you have it, it’s actually very easy to solve many different common problems.

Next week, we’ll do a bit more work with FGL. We’ll explore a problem that isn’t quite as cut-and-dried as the ones we looked at here. First we'll take a more abstract problem and determine what graph we want to make for it. Then we'll solve that graph problem using FGL. So check back next week on the Monday Morning Haskell blog!

The easiest way to bring FGL into your Haskell code is to use the Stack tool. If you’re unfamiliar with this, you should take our free Stack mini-course. You’ll learn the important step of how to bring dependencies into your code. You’ll also see the different components in a Stack program and the commands you can run to manipulate them.

If you’ve never programmed in Haskell before, you should try it! Download our Getting Started Checklist. It’ll point you towards some valuable resources in your early Haskell education.

James Bowen

Defeating Evil with Data Structures!

In last week’s article, we used benchmarks to determine how well our code performs on certain inputs. First we used the Criterion library to get some measurements for our code. Then we were able to look at those measurements in some spiffy output. We also profiled our code to try to determine what part of our code was slowing us down.

The profiling output highlighted two functions that were taking an awful lot of time. When we analyzed them, we found they were very inefficient. In this article, we’ll resolve those problems and improve our code in a couple of different ways. First, we’ll use an array rather than a list to make our value accesses faster. Then, we’ll add a cool data structure called a segment tree. This will help us to quickly get the smallest height value over a particular interval.

The code examples in this article series make good use of the Stack tool. If you’ve never used Stack before, you should check out our new FREE Stack mini-course. It’ll walk you through the basics of organizing your code, getting dependencies, and running commands.

What Went Wrong?

So first let’s take some time to remind ourselves why these functions were slowing us down a lot. Both our minimum height function and our value at index function ran in O(n) time. This means each of them could scan the entire list in the worst case. Next we observed that both of these functions will get called O(n) times. Thus our total algorithm will be O(n^2) time. The time benchmarks we took backed up this theory.

The data structures we mentioned above will help us get the values we need without doing a full scan. We'll start by substituting in an array for our list, since that is a great deal easier.

Arrays

Linked lists are very common when we’re solving functional programming problems. They have some nice properties, and work very well with recursion. However, they do not allow fast access by index. For these situations, we need to use arrays. Arrays aren't as common in Haskell as other languages, and there are a few differences.

First, Haskell arrays have two type parameters. When you make an array in Java, you say whether it’s an int array (int[]) or a string array (String[]), or whatever other type. So this is only a single parameter. Whenever we want to index into the array, we always use integers.

In Haskell, we get to choose both the type that the array stores AND the type that indexes the array. The indexing type has to belong to the index (Ix) typeclass, and in this case we’ll be using Int anyway. But it’s cool to know that you have more flexibility. For instance, consider representing a matrix. In Java, we have to use an “array of arrays”, which involves a lot of awkward syntax. In Haskell, we can instead use a single array indexed by tuples of integers! We could also do something like index from 1 instead of 0 if we wanted.
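Here's a small sketch of both ideas, using only the standard Data.Array module. The names matrix and oneBased are made up for illustration and aren't part of our fence code:

```haskell
import Data.Array (Array, bounds, listArray, (!))

-- A 2x3 matrix stored in ONE array, indexed by (row, column) tuples.
-- listArray fills elements in index order: (0,0), (0,1), (0,2), (1,0), ...
matrix :: Array (Int, Int) Int
matrix = listArray ((0, 0), (1, 2)) [1, 2, 3, 4, 5, 6]

-- An array whose indices run from 1 to 5 rather than from 0 to 4.
oneBased :: Array Int Char
oneBased = listArray (1, 5) "hello"
```

In GHCi, matrix ! (1, 0) gives 4, oneBased ! 1 gives 'h', and bounds oneBased gives (1,5).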

So for our problem, we’ll use Array Int Int for our inner fence values instead of a normal list. We'll only need to make a few code changes though! First, we'll import a couple modules and change our type to use the array:

import Data.Array
import Data.Ix (range)

...

newtype FenceValues = FenceValues { unFenceValues :: Array Int Int }

Next, instead of using (!!) to access by index, we’ll use the specialized array index (!) operator to access them.

valueAtIndex :: FenceValues -> FenceIndex -> Int
valueAtIndex values index = (unFenceValues values) ! (unFenceIndex index)

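Finally, let's improve our minimumHeightIndexValue function. We’ll now use the range function together with array indexing instead of resorting to drop and take. Note we now use right - 1: range includes both of its endpoints, and we want to exclude the right endpoint of the interval.

minimumHeightIndexValue :: FenceValues -> FenceInterval -> (FenceIndex, Int)
minimumHeightIndexValue values (FenceInterval (FenceIndex left, FenceIndex right)) =
  foldl minTuple (FenceIndex (-1), maxBound :: Int) valsInInterval
  where
    valsInInterval :: [(FenceIndex, Int)]
    valsInInterval = zip 
      (FenceIndex <$> intervalRange) 
      (map ((unFenceValues values) !) intervalRange)
      where
        -- range is inclusive on both ends, so stop at right - 1
        intervalRange = range (left, right - 1)
    minTuple :: (FenceIndex, Int) -> (FenceIndex, Int) -> (FenceIndex, Int)
    minTuple old@(_, heightOld) new@(_, heightNew) =
      if heightNew < heightOld then new else old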
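To see why we stop at right - 1, here's a standalone sketch that uses a plain Array Int Int in place of FenceValues. The helper name valuesInInterval is hypothetical, and the sample heights are the [2,5,7,4,1,8] example from the original problem:

```haskell
import Data.Array (Array, listArray, (!))
import Data.Ix (range)

-- range includes BOTH endpoints: range (1, 3) == [1, 2, 3].
-- Fence intervals are half-open, [left, right), so the last index is right - 1.
valuesInInterval :: Array Int Int -> Int -> Int -> [(Int, Int)]
valuesInInterval arr left right =
  [ (i, arr ! i) | i <- range (left, right - 1) ]

sampleFence :: Array Int Int
sampleFence = listArray (0, 5) [2, 5, 7, 4, 1, 8]
```

Here valuesInInterval sampleFence 1 4 yields [(1,5),(2,7),(3,4)], and an empty interval such as (2, 2) yields [] because range (2, 1) is empty.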

We’ll also have to change our benchmarking code to produce arrays instead of lists:

import Data.Array(listArray)

…

randomList :: Int -> IO FenceValues
randomList n = FenceValues . mkListArray <$> 
  (sequence $ replicate n (randomRIO (1, 10000 :: Int)))
  where
    mkListArray vals = listArray (0, (length vals) - 1) vals

Both our library and our benchmark now need the array package in the build-depends section of the Cabal file, so we need to make sure we add it! Once we have, we can benchmark our code again, and we’ll find it’s already sped up quite a bit!

>> stack bench --profile
Running 1 benchmarks...
Benchmark fences-benchmarks: RUNNING...
benchmarking fences tests/Size 1 Test
time                 49.33 ns   (48.98 ns .. 49.71 ns)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 49.46 ns   (49.16 ns .. 49.86 ns)
std dev              1.105 ns   (861.0 ps .. 1.638 ns)
variance introduced by outliers: 33% (moderately inflated)

benchmarking fences tests/Size 10 Test
time                 4.541 μs   (4.484 μs .. 4.594 μs)
                     0.999 R²   (0.998 R² .. 1.000 R²)
mean                 4.496 μs   (4.456 μs .. 4.531 μs)
std dev              132.0 ns   (109.6 ns .. 164.3 ns)
variance introduced by outliers: 36% (moderately inflated)

benchmarking fences tests/Size 100 Test
time                 79.81 μs   (79.21 μs .. 80.45 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 79.51 μs   (78.93 μs .. 80.39 μs)
std dev              2.396 μs   (1.853 μs .. 3.449 μs)
variance introduced by outliers: 29% (moderately inflated)

benchmarking fences tests/Size 1000 Test
time                 1.187 ms   (1.158 ms .. 1.224 ms)
                     0.995 R²   (0.992 R² .. 0.998 R²)
mean                 1.170 ms   (1.155 ms .. 1.191 ms)
std dev              56.61 μs   (48.02 μs .. 70.28 μs)
variance introduced by outliers: 37% (moderately inflated)

benchmarking fences tests/Size 10000 Test
time                 15.03 ms   (14.71 ms .. 15.32 ms)
                     0.997 R²   (0.994 R² .. 0.999 R²)
mean                 15.71 ms   (15.44 ms .. 16.03 ms)
std dev              729.7 μs   (569.3 μs .. 965.4 μs)
variance introduced by outliers: 16% (moderately inflated)

benchmarking fences tests/Size 100000 Test
time                 191.4 ms   (189.2 ms .. 193.9 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 189.3 ms   (188.2 ms .. 190.5 ms)
std dev              1.471 ms   (828.0 μs .. 1.931 ms)
variance introduced by outliers: 14% (moderately inflated)

Benchmark fences-benchmarks: FINISH

Here are the multiplicative factors:

Size 1: 49.33 ns
Size 10: 4.541 μs (increased ~90x)
Size 100: 79.81 μs (increased ~18x)
Size 1000: 1.187 ms (increased ~15x)
Size 10000: 15.03 ms (increased ~13x)
Size 100000: 191.4 ms (increased ~13x)

For the later cases, increasing the size by a factor of 10 only seems to increase the time by a factor of 13-15. We could be forgiven for thinking we have achieved O(n log n) time already!
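As a rough check of that intuition: if run time were exactly proportional to n log n, growing from 10,000 to 100,000 elements would multiply the time by 10 · log(100000) / log(10000) = 12.5, while a quadratic algorithm would multiply it by 100. Here's a tiny sketch that computes these predicted ratios. The function names are our own, and the model ignores constant overheads:

```haskell
-- Predicted growth in run time when the input grows from n1 to n2,
-- assuming run time is exactly proportional to the given cost function.
growth :: (Double -> Double) -> Double -> Double -> Double
growth cost n1 n2 = cost n2 / cost n1

-- Cost models to compare against the observed factors.
nLogN :: Double -> Double
nLogN n = n * log n

quadratic :: Double -> Double
quadratic n = n * n
```

growth nLogN 1000 10000 comes out to about 13.3 and growth nLogN 10000 100000 to 12.5, so the observed 13-15x factors fit n log n far better than n^2. A benchmark alone can't prove the bound, though, which is exactly the point of the next section.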

Better Test Cases

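But something still doesn't sit right. We have to remember that the theory doesn’t quite justify our excitement here. In fact, our old code was SO BAD that the NORMAL case was O(n^2). Now it seems like we may have gotten O(n log n) for the average case. But we want to prepare for the worst case if we can. Imagine our code is being used by our evil adversary.

He’ll find the worst possible case! In this situation, our code will not be so performant when the list of input heights is sorted! With sorted input, the minimum of every interval is its leftmost element, so each recursive call only peels off a single fence post. That gives us n levels of recursion, and each level still scans its whole interval to find the minimum, which adds up to O(n^2) work.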
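To make the worst case concrete, here's a sketch that measures how deep the "split at the minimum" recursion goes. It models the fence as a plain list rather than FenceValues, and the name splitDepth is hypothetical:

```haskell
-- How many levels deep the "split at the minimum" recursion goes.
-- Plain lists stand in for FenceValues here.
splitDepth :: [Int] -> Int
splitDepth [] = 0
splitDepth xs = 1 + max (splitDepth leftPart) (splitDepth rightPart)
  where
    minVal = minimum xs
    -- Everything before the first occurrence of the minimum...
    (leftPart, rest) = break (== minVal) xs
    -- ...and everything after it.
    rightPart = drop 1 rest
```

For the example fence [2,5,7,4,1,8] the depth is 5, but splitDepth [1 .. 1000] is 1000: on sorted input, every level peels off just one bar.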

main :: IO ()
main = do
  [l1, l2, l3, l4, l5, l6] <- mapM 
    randomList [1, 10, 100, 1000, 10000, 100000]
  let l7 = sortedList
  defaultMain
    [ bgroup "fences tests" 
        ...
      , bench "Size 100000 Test" $ whnf largestRectangle l6
      , bench "Size 100000 Test (sorted)" $ whnf largestRectangle l7
      ]
    ]

...

sortedList :: FenceValues
sortedList = FenceValues $ listArray (0, 99999) [1..100000]

We’ll once again find that this last case takes a loooong time, and we’ll see a big spike in run time.

>> stack bench --profile
Running 1 benchmarks...
Benchmark fences-benchmarks: RUNNING...

…

benchmarking fences tests/Size 100000 Test (sorted)
time                 378.1 s    (355.0 s .. 388.3 s)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 384.5 s    (379.3 s .. 387.2 s)
std dev              4.532 s    (0.0 s .. 4.670 s)
variance introduced by outliers: 19% (moderately inflated)

Benchmark fences-benchmarks: FINISH

It averages more than 6 minutes per case! But this time, we’ll see the profiling output has changed. It only calls out various portions of minimumHeightIndexValue! We no longer spend a lot of time in valueAtIndex.

COST CENTRE                                          %time %alloc

minimumHeightIndexValue.valsInInterval               65.0   67.7
minimumHeightIndexValue                              22.4    0.0
minimumHeightIndexValue.valsInInterval.intervalRange 12.4   32.2

So now we have to solve this new problem by improving our calculation of the minimum.

Segment Trees

Our current approach still requires us to look at every element in our interval. Even though some of our intervals will be small, there will be a lot of these smaller calls, so the total time is still O(n^2). We need a way to find the smallest value and its index on a given interval without resorting to a linear scan.

One idea would be to develop an exhaustive list of all the answers to this question right at the start. We could make a mapping from all possible intervals to the smallest index and value in each interval. But this won’t help us in the end. There are n(n+1)/2 possible intervals, which is still O(n^2). So creating this data structure would still mean that our code takes O(n^2) time.
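Here's the arithmetic behind that claim as a quick sketch (both function names are ours). A non-empty half-open interval [i, j) is fixed by choosing two distinct endpoints from the n + 1 positions 0..n:

```haskell
-- Number of non-empty half-open intervals [i, j) with 0 <= i < j <= n.
-- Closed form: choose 2 endpoints out of n + 1 positions, n(n+1)/2.
numIntervals :: Integer -> Integer
numIntervals n = n * (n + 1) `div` 2

-- Brute-force count, to check the closed form on small inputs.
countIntervals :: Int -> Int
countIntervals n = length [ (i, j) | i <- [0 .. n - 1], j <- [i + 1 .. n] ]
```

For our 100,000-element benchmark, numIntervals 100000 is 5,000,050,000, so even building such a table would be hopeless.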

But we’re on the right track with the idea of doing some of the work beforehand. We'll have to use a data structure that’s not an exhaustive listing though. Enter segment trees.

A segment tree is a binary tree, shaped much like a balanced binary search tree. Instead of storing a single value though, each node corresponds to an interval. Each node will store its interval, the smallest value over that interval, and the index of that value.

The top node of the tree will refer to the interval of the whole array. It'll store the pair for the smallest value and index overall. Then it will have two child nodes. The left one will have the minimum pair over the first half of the array, and the right one will have the second half. The next layer will break it up into quarters, and so on.
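Here's a small sketch of that layering, listing the half-open interval each node covers at every depth. It uses the same halving rule as the buildSegmentTreeTail code later in this article, but levels is a hypothetical helper, not part of our solution:

```haskell
-- The half-open intervals [lo, hi) covered by the nodes at each depth of a
-- segment tree over [0, n), splitting each interval at (lo + hi) `quot` 2.
levels :: Int -> [[(Int, Int)]]
levels n = takeWhile (not . null) (iterate nextLevel [(0, n) | n > 0])
  where
    nextLevel intervals = concatMap split intervals
    split (lo, hi)
      | hi - lo <= 1 = []  -- single-element intervals are leaves
      | otherwise = let mid = (lo + hi) `quot` 2 in [(lo, mid), (mid, hi)]
```

levels 4 gives [[(0,4)],[(0,2),(2,4)],[(0,1),(1,2),(2,3),(3,4)]]. A tree over our 100,000 fence values has only 18 levels, which is part of why a query that walks down the tree stays cheap.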

As an example, let's consider how we would determine the minimum pair starting from the first quarter point and ending at the third quarter point. We’ll do this using recursion. First, we'll ask the left subtree for the minimum pair on the interval from the quarter point to the half point. Then we’ll query the right tree for the smallest pair from the half point to the three-quarters point. Then we can take the smallest of those and return it. I won’t go into all the theory here, but it turns out that even in the worst case this operation takes O(log n) time.
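Here's a sketch of that splitting logic on its own. Instead of computing a minimum, it returns the node intervals a query ends up using. It assumes a tree over [0, n) built with the same halving rule, and decompose is a hypothetical name:

```haskell
-- Which node intervals does a query [ql, qr) break into, starting from a
-- node that covers [lo, hi)? Mirrors the recursive query described above.
decompose :: (Int, Int) -> (Int, Int) -> [(Int, Int)]
decompose (lo, hi) (ql, qr)
  | ql >= qr = []                        -- empty query
  | ql == lo && qr == hi = [(lo, hi)]    -- exact match: use this node
  | qr <= mid = decompose (lo, mid) (ql, qr)  -- entirely in the left half
  | ql >= mid = decompose (mid, hi) (ql, qr)  -- entirely in the right half
  | otherwise =                          -- straddles the midpoint: split it
      decompose (lo, mid) (ql, mid) ++ decompose (mid, hi) (mid, qr)
  where
    mid = (lo + hi) `quot` 2
```

For a tree over [0, 8), the quarter-point to three-quarter-point query decompose (0, 8) (2, 6) returns [(2,4),(4,6)]: one node from each half. A less tidy query such as (1, 7) uses [(1,2),(2,4),(4,6),(6,7)], still only a handful of nodes.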

Designing Our Segment Tree

There is a library called Data.SegmentTree on Hackage. But I thought it would be interesting to implement this data structure from scratch. We'll compose our tree from SegmentTreeNodes. Each node is either empty, or it contains six fields. The first two refer to the interval the node spans. The next will be the minimum value and the index of that value over the interval. And then we’ll have fields for each of the children nodes of this node:

data SegmentTreeNode = ValueNode
  { fromIndex :: FenceIndex
  , toIndex :: FenceIndex
  , value :: Int
  , minIndex :: FenceIndex
  , leftChild :: SegmentTreeNode
  , rightChild :: SegmentTreeNode
  }
  | EmptyNode

We could make this Segment Tree type a lot more generic so that it isn’t restricted to our fence problem. I would encourage you to take this code and try that as an exercise!
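As one possible take on that exercise, here's a sketch of a segment tree that's generic in what each node stores: we pass in an associative combining function. All the names here (SegTree, buildTree, queryTree) are our own, and it assumes a non-empty array indexed from 0 and non-empty queries:

```haskell
import Data.Array (Array, listArray, (!))

-- Each internal node covers a half-open interval [lo, hi) and caches the
-- combined summary of that interval. Leaves cover a single index.
data SegTree a
  = Leaf a
  | Node Int Int a (SegTree a) (SegTree a)

summary :: SegTree a -> a
summary (Leaf x) = x
summary (Node _ _ x _ _) = x

-- Build a tree over [lo, hi). 'combine' must be associative.
buildTree :: (a -> a -> a) -> Array Int a -> Int -> Int -> SegTree a
buildTree combine arr lo hi
  | lo + 1 >= hi = Leaf (arr ! lo)
  | otherwise = Node lo hi (combine (summary l) (summary r)) l r
  where
    mid = (lo + hi) `quot` 2
    l = buildTree combine arr lo mid
    r = buildTree combine arr mid hi

-- Combine the values over [ql, qr); assumes lo <= ql < qr <= hi.
queryTree :: (a -> a -> a) -> SegTree a -> Int -> Int -> a
queryTree _ (Leaf x) _ _ = x
queryTree combine (Node lo hi x l r) ql qr
  | ql == lo && qr == hi = x
  | qr <= mid = queryTree combine l ql qr
  | ql >= mid = queryTree combine r ql qr
  | otherwise = combine (queryTree combine l ql mid) (queryTree combine r mid qr)
  where
    mid = (lo + hi) `quot` 2
```

For the fence problem we'd store (height, index) pairs and combine them with min, so ties go to the smaller index. Over the heights [2,5,7,4,1,8], querying [1, 4) gives (4,3). Swapping min for (+) over the raw heights turns the same code into a range-sum tree, where the same query gives 16.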

Building the Segment Tree

Now we’ll add our preprocessing step where we’ll actually build the tree itself. This will use the same interval/tail pattern we saw before. In the base case, the interval’s span is only 1, so we make a node containing that value with empty children. We’ll also add a catch-all that returns an EmptyNode:

buildSegmentTree :: Array Int Int -> SegmentTreeNode
buildSegmentTree ints = buildSegmentTreeTail 
  ints 
  (FenceInterval ((FenceIndex 0), (FenceIndex (length (elems ints)))))

buildSegmentTreeTail :: Array Int Int -> FenceInterval -> SegmentTreeNode
buildSegmentTreeTail array
  (FenceInterval (wrappedFromIndex@(FenceIndex fromIndex), wrappedToIndex@(FenceIndex toIndex)))
  | fromIndex + 1 == toIndex = ValueNode 
      { fromIndex = wrappedFromIndex
      , toIndex = wrappedToIndex
      , value = array ! fromIndex
      , minIndex = wrappedFromIndex
      , leftChild = EmptyNode
      , rightChild = EmptyNode
      }
  | … missing case
  | otherwise = EmptyNode

Now our middle case will be the standard case. First we’ll divide our interval in half and make two recursive calls.

where 
  average = (fromIndex + toIndex) `quot` 2
  -- Recursive Calls
  leftChild = buildSegmentTreeTail 
    array (FenceInterval (wrappedFromIndex, (FenceIndex average)))
  rightChild = buildSegmentTreeTail 
    array (FenceInterval ((FenceIndex average), wrappedToIndex))

Next we’ll write a function that’ll extract the minimum value and index, but handle the empty node case:

-- Get minimum val and index, but account for empty case.
valFromNode :: SegmentTreeNode -> (Int, FenceIndex)
valFromNode EmptyNode = (maxBound :: Int, FenceIndex (-1))
valFromNode n@ValueNode{} = (value n, minIndex n)

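Now we’ll compare three candidates for this minimum: the results from the left and right children, and the value at the start of the current interval. That last value already lies inside the left child’s interval, so it can never change the result, but including it does no harm.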

leftCase = valFromNode leftChild
rightCase = valFromNode rightChild
currentCase = (array ! fromIndex, wrappedFromIndex)
(newValue, newIndex) = min (min leftCase rightCase) currentCase

Finally we’ll complete our definition by filling in the missing variables in the middle/normal case. Here’s the full function:

buildSegmentTreeTail :: Array Int Int -> FenceInterval -> SegmentTreeNode
buildSegmentTreeTail array
  (FenceInterval (wrappedFromIndex@(FenceIndex fromIndex), wrappedToIndex@(FenceIndex toIndex)))
  | fromIndex + 1 == toIndex = ValueNode 
      { fromIndex = wrappedFromIndex
      , toIndex = wrappedToIndex
      , value = array ! fromIndex
      , minIndex = wrappedFromIndex
      , leftChild = EmptyNode
      , rightChild = EmptyNode
      }
  | fromIndex < toIndex = ValueNode 
    { fromIndex = wrappedFromIndex
    , toIndex = wrappedToIndex
    , value = newValue
    , minIndex = newIndex
    , leftChild = leftChild
    , rightChild = rightChild
    }
  | otherwise = EmptyNode
    where 
      average = (fromIndex + toIndex) `quot` 2
      -- Recursive Calls
      leftChild = buildSegmentTreeTail 
        array (FenceInterval (wrappedFromIndex, (FenceIndex average)))
      rightChild = buildSegmentTreeTail 
        array (FenceInterval ((FenceIndex average), wrappedToIndex))

      -- Get minimum val and index, but account for empty case.
      valFromNode :: SegmentTreeNode -> (Int, FenceIndex)
      valFromNode EmptyNode = (maxBound :: Int, FenceIndex (-1))
      valFromNode n@ValueNode{} = (value n, minIndex n)

      leftCase = valFromNode leftChild
      rightCase = valFromNode rightChild
      currentCase = (array ! fromIndex, wrappedFromIndex)
      (newValue, newIndex) = min (min leftCase rightCase) currentCase

Finding the Minimum

Now let’s write the critical function of finding the minimum over the given interval. We’ll add our tree as another parameter. Then we’ll handle the EmptyNode case and then unwrap our values for the full case:

minimumHeightIndexValue :: FenceValues -> SegmentTreeNode -> FenceInterval -> (FenceIndex, Int)
minimumHeightIndexValue values tree 
  originalInterval@(FenceInterval (FenceIndex left, FenceIndex right)) =
  case tree of
    EmptyNode -> (FenceIndex (-1), maxBound :: Int)
    ValueNode
      { fromIndex = FenceIndex nFromIndex
      , toIndex = FenceIndex nToIndex
      , value = nValue
      , minIndex = nMinIndex
      , leftChild = nLeftChild
      , rightChild = nRightChild }

Next we’ll handle the base case, where the query interval exactly matches this node’s interval. Since these guards belong to a case alternative, they use -> rather than =:

| left == nFromIndex && right == nToIndex -> (nMinIndex, nValue)

Next we’ll observe two cases that will need only one recursive call. Since the right index is exclusive, if it’s at or below the midway point, the whole query fits in the left sub-child, so we recursively call on that side. And if the left index is at or above the midway point, we’ll call on the right side (we’ll calculate the average later).

| otherwise -> if right <= average 
  then minimumHeightIndexValue values nLeftChild originalInterval
  else if left >= average
    then minimumHeightIndexValue values nRightChild originalInterval

Finally we have a tricky part. If the interval does cross the halfway mark, we’ll have to divide it into two sub-intervals. Then we’ll make two recursive calls, and get their solutions. Finally, we’ll compare the two solutions and take the smaller one.

else minTuple leftResult rightResult
  where
    average = (nFromIndex + nToIndex) `quot` 2
    leftResult = minimumHeightIndexValue values nLeftChild
      (FenceInterval (FenceIndex left, FenceIndex average))
    rightResult = minimumHeightIndexValue values nRightChild
      (FenceInterval (FenceIndex average, FenceIndex right))
    minTuple :: (FenceIndex, Int) -> (FenceIndex, Int) -> (FenceIndex, Int)
    minTuple old@(_, heightOld) new@(_, heightNew) =
      if heightNew < heightOld then new else old

Here’s the full function for clarity:

minimumHeightIndexValue :: FenceValues -> SegmentTreeNode -> FenceInterval -> (FenceIndex, Int)
minimumHeightIndexValue values tree
  originalInterval@(FenceInterval (FenceIndex left, FenceIndex right)) =
  case tree of
    -- An empty subtree yields a sentinel that never wins a comparison
    EmptyNode -> (FenceIndex (-1), maxBound :: Int)
    ValueNode
      { fromIndex = FenceIndex nFromIndex
      , toIndex = FenceIndex nToIndex
      , value = nValue
      , minIndex = nMinIndex
      , leftChild = nLeftChild
      , rightChild = nRightChild }
      -- The query is exactly this node's interval: use the stored answer
      | left == nFromIndex && right == nToIndex -> (nMinIndex, nValue)
      | otherwise -> if right <= average
          -- The right index is exclusive, so the query fits in the left half
          then minimumHeightIndexValue values nLeftChild originalInterval
          else if left >= average
            -- The query fits entirely in the right half
            then minimumHeightIndexValue values nRightChild originalInterval
            -- The query straddles the midpoint: split it, keep the smaller result
            else minTuple leftResult rightResult
      where
        average = (nFromIndex + nToIndex) `quot` 2
        leftResult = minimumHeightIndexValue values nLeftChild
          (FenceInterval (FenceIndex left, FenceIndex average))
        rightResult = minimumHeightIndexValue values nRightChild
          (FenceInterval (FenceIndex average, FenceIndex right))
        minTuple :: (FenceIndex, Int) -> (FenceIndex, Int) -> (FenceIndex, Int)
        minTuple old@(_, heightOld) new@(_, heightNew) =
          if heightNew < heightOld then new else old

Touching Up the Rest

Once we’ve accomplished this, the rest is pretty straightforward. First, we’ll build our segment tree at the beginning and pass that as a parameter to our function. Then we’ll plug in our new minimum function in place of the old one. We’ll make sure to add the tree to each recursive call as well.

largestRectangle :: FenceValues -> FenceSolution
largestRectangle values = largestRectangleAtIndices values 
  (buildSegmentTree (unFenceValues values))
  (FenceInterval (FenceIndex 0, FenceIndex (length (unFenceValues values))))

…
-- Notice the extra parameter
largestRectangleAtIndices :: FenceValues -> SegmentTreeNode -> FenceInterval -> FenceSolution
largestRectangleAtIndices
  values
  tree
…
      where
      …
      -- And down here add it to each call
      (minIndex, minValue) = minimumHeightIndexValue values tree interval
      leftCase = largestRectangleAtIndices values tree (FenceInterval (leftIndex, minIndex))
      rightCase = if minIndex + 1 == rightIndex
        then FenceSolution (minBound :: Int) -- a placeholder that max always ignores
        else largestRectangleAtIndices values tree (FenceInterval (minIndex + 1, rightIndex))

And now we can run our benchmarks again. This time, the sorted case finishes in about a quarter of a second (a mean of 243.5 ms) instead of more than six minutes, and the random case holds steady at under 200 ms. Success!

benchmarking fences tests/Size 100000 Test
time                 179.1 ms   (173.5 ms .. 185.9 ms)
                     0.999 R²   (0.998 R² .. 1.000 R²)
mean                 184.1 ms   (182.7 ms .. 186.1 ms)
std dev              2.218 ms   (1.197 ms .. 3.342 ms)
variance introduced by outliers: 14% (moderately inflated)

benchmarking fences tests/Size 100000 Test (sorted)
time                 238.4 ms   (227.2 ms .. 265.1 ms)
                     0.998 R²   (0.989 R² .. 1.000 R²)
mean                 243.5 ms   (237.0 ms .. 251.8 ms)
std dev              8.691 ms   (2.681 ms .. 11.83 ms)
variance introduced by outliers: 16% (moderately inflated)

Conclusion

So in these past two articles we’ve learned a whole lot. We first covered how to create benchmarks for our code using Cabal/Stack. When we ran those benchmarks, we found results took longer than we would like. We then used profiling to determine what the problematic functions were. Then we dove head-first into some data structures knowledge. We saw first hand how changing the underlying data structures of our program could improve our performance. We also learned about arrays, which are somewhat overlooked in Haskell. Then we built a segment tree from scratch and used its API to enable our program’s improvements.

This problem involved many different uses of recursion. If you want to become a better functional programmer, you’d better learn recursion. If you want a better grasp of this fundamental concept, you should check out our FREE Recursion Workbook. It has two chapters of useful information as well as 10 practice problems!

If you’ve never written Haskell before but are intrigued by the possibilities you saw in this article, you should try it out! Download our Getting Started Checklist! It’ll walk you through installing the language. It'll also point you to some cool resources for starting your Haskell education.

Finally, be sure to check out our Stack mini-course. Once you’ve mastered the Stack tool, you’ll be well on your way to making Haskell projects like a Pro!

James Bowen

How well does it work? Profiling in Haskell

I’ve said it before, but I’ll say it again. As much as we’d like to think it’s the case, our Haskell code doesn’t work just because it compiles. This is why we have test suites. But even if it passes our test suites, this doesn’t mean it works as well as it could either. Sometimes we’ll realize that the code we wrote isn’t quite performant enough, so we’ll have to make improvements.

But improving our code can sometimes feel like taking shots in the dark. You'll spend a great deal of time tweaking a certain piece. Then you'll find you haven’t actually made much of a dent in the total run time of the application. Certain operations generally take longer, like database calls, network operations, and IO. So you can often have a decent idea of where to start. But it always helps to be sure. This is where benchmarking and profiling come in. We’re going to take a specific problem and learn how we can use some Haskell tools to zero in on the problem point.

As a note, the tools we’ll use require you to be organizing your code using Stack or Cabal. If you’ve never used either of these before, you should check out our Stack Mini Course! It'll teach you the basics of creating a project with Stack. You'll also learn the primary commands to use with Stack. It’s brand new and best of all FREE! Check it out! It’s our first course of any kind, so we’re looking for feedback!

The Problem

Our overarching problem for this article will be the “largest rectangle” problem. You can actually try to solve this problem yourself on HackerRank under the name “John and Fences”. Imagine we have a series of vertical bars with varying heights placed next to each other. We want to find the area of the largest rectangle that we can draw over these bars that doesn’t include any empty space. Here’s a visualization of one such problem and solution:

In this example, we have posts with heights [2,5,7,4,1,8]. The largest rectangle we can form has an area of 12, as we see with the highlighted squares.
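Before writing the real solution, it can help to have a slow but obviously correct answer to test against. Here's a brute-force sketch over plain lists (bruteForceRectangle is our own name, not part of the article's code):

```haskell
import Data.List (tails)

-- Reference answer by brute force: for every starting bar, extend the
-- rectangle one bar at a time. The running minimum is the tallest height
-- that still fits, so each candidate area is width * running minimum.
-- This is O(n^2), so it's only suitable for small inputs.
bruteForceRectangle :: [Int] -> Int
bruteForceRectangle heights = maximum (0 : concatMap areasFrom (tails heights))
  where
    areasFrom start = zipWith (*) [1 ..] (scanl1 min start)
```

bruteForceRectangle [2,5,7,4,1,8] returns 12, matching the example: the bars of heights 5, 7, and 4 form a 3-wide rectangle of height 4.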

This problem is pretty neat and clean to solve with Haskell, as it lends itself to a recursive solution. First let’s define a couple of newtypes to illustrate our concepts for this problem. We’ll use the GeneralizedNewtypeDeriving compiler extension to derive the Num typeclass on our index type, as this will be useful later.

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
...
newtype FenceValues = FenceValues { unFenceValues :: [Int] }
newtype FenceIndex = FenceIndex { unFenceIndex :: Int }
  deriving (Eq, Num, Ord)
-- Left Index is inclusive, right index is non-inclusive 
newtype FenceInterval = FenceInterval { unFenceInterval :: (FenceIndex, FenceIndex) }
newtype FenceSolution = FenceSolution { unFenceSolution :: Int }
  deriving (Eq, Show, Ord)

Next, we’ll define our primary function. It will take our FenceValues, a list of integers, and return our solution.

largestRectangle :: FenceValues -> FenceSolution
largestRectangle values = ...

It in turn will call our recursive helper function. This function will calculate the largest rectangle over a specific interval. We can solve it recursively by using smaller and smaller intervals. We’ll start by calling it on the interval of the whole list.

largestRectangle :: FenceValues -> FenceSolution
largestRectangle values = largestRectangleAtIndices values
  (FenceInterval (FenceIndex 0, FenceIndex (length (unFenceValues values))))

largestRectangleAtIndices :: FenceValues -> FenceInterval -> FenceSolution
largestRectangleAtIndices = ...

Now, to break this into recursive cases, we need some more information first. What we need is the index i of the minimum height in this interval. One option is that we could make a rectangle spanning the whole interval with this height.

Any other "largest rectangle" won't use this particular index. So we can then divide our problem into two more cases. In the first, we'll find the largest rectangle on the interval to the left. In the second, we'll look to the right.
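Before we get to the index bookkeeping, here's the same three-way recursion sketched over plain lists. It's a simplified model (largestRect is a hypothetical name) that splits the list around the first occurrence of the minimum rather than tracking FenceIndex values:

```haskell
-- Largest rectangle via "split at the minimum". The answer is the best of:
--   1. a rectangle spanning the whole list at the minimum height,
--   2. the best rectangle strictly left of the minimum,
--   3. the best rectangle strictly right of it.
largestRect :: [Int] -> Int
largestRect [] = 0
largestRect heights =
  maximum [minVal * length heights, largestRect leftPart, largestRect rightPart]
  where
    minVal = minimum heights
    (leftPart, rest) = break (== minVal) heights
    rightPart = drop 1 rest
```

largestRect [2,5,7,4,1,8] evaluates to 12. An empty side simply contributes 0 here, which plays the same role as the placeholder for a missing right interval in the real code.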

As you might realize, these two cases simply involve making recursive calls! Then we can easily compare their results. The only thing we need to add is a base case. Here are all these cases represented in code:

largestRectangleAtIndices :: FenceValues -> FenceInterval -> FenceSolution
largestRectangleAtIndices
  values
  interval@(FenceInterval (leftIndex, rightIndex)) = 
    -- Base Case: Checks if left + 1 >= right
    if isBaseInterval interval
      then FenceSolution (valueAtIndex values leftIndex)
      -- Compare three cases
      else max (max middleCase leftCase) rightCase
      where
      -- Find the minimum height and its index
      (minIndex, minValue) = minimumHeightIndexValue values interval
      -- Case 1: Use the minimum index
      middleCase = FenceSolution $ (intervalSize interval) * minValue
      -- Recursive call #1
      leftCase = largestRectangleAtIndices values (FenceInterval (leftIndex, minIndex))
      -- Guard against case where there is no "right" interval
      rightCase = if minIndex + 1 == rightIndex
        then FenceSolution (minBound :: Int) -- Supply a "fake" solution that max will always ignore
        -- Recursive call #2
        else largestRectangleAtIndices values (FenceInterval (minIndex + 1, rightIndex))

And just like that, we’re actually almost finished. The only sticking point here is a few helper functions. Three of these are simple:

valueAtIndex :: FenceValues -> FenceIndex -> Int
valueAtIndex values index = (unFenceValues values) !! (unFenceIndex index)

isBaseInterval :: FenceInterval -> Bool
isBaseInterval (FenceInterval (FenceIndex left, FenceIndex right)) = left + 1 >= right

intervalSize :: FenceInterval -> Int
intervalSize (FenceInterval (FenceIndex left, FenceIndex right)) = right - left

Now we have to determine the minimum on this interval. Let’s do this in the most naive way, by scanning the whole interval with a fold.

minimumHeightIndexValue :: FenceValues -> FenceInterval -> (FenceIndex, Int)
minimumHeightIndexValue values (FenceInterval (FenceIndex left, FenceIndex right)) =
  foldl minTuple (FenceIndex (-1), maxBound :: Int) valsInInterval
  where
    valsInInterval :: [(FenceIndex, Int)]
    valsInInterval = drop left (take right (zip (FenceIndex <$> [0..]) (unFenceValues values)))
    minTuple :: (FenceIndex, Int) -> (FenceIndex, Int) -> (FenceIndex, Int)
    minTuple old@(_, heightOld) new@(_, heightNew) =
      if heightNew < heightOld then new else old

And now we’re done!

Benchmarking our Code

Now, this is a neat little algorithmic solution, but we want to know if our code is efficient. We need to know if it will scale to larger input values. We can find the answer to this question by writing benchmarks. Benchmarks are a feature we can use in conjunction with Cabal and Stack. They work a lot like test suites. But instead of proving the correctness of our code, they’ll show us how fast our code runs under various circumstances. We’ll use the Criterion library to do this. We’ll start by adding a section in our .cabal file for this benchmark:

benchmark fences-benchmarks
  type:                exitcode-stdio-1.0
  hs-source-dirs:      benchmark
  main-is:             fences-benchmark.hs
  build-depends:       base
                     , FencesExample
                     , criterion
                     , random
  default-language:    Haskell2010

Now we’ll make our file fences-benchmark.hs, make it a Main module and add a main function. We’ll generate 6 lists, increasing in size by a factor of 10 each time. Then we’ll create a benchmark group and call the bench function on each situation.

module Main where

import Criterion
import Criterion.Main (defaultMain)
import System.Random

import Lib

main :: IO ()
main = do
  [l1, l2, l3, l4, l5, l6] <- mapM 
    randomList [1, 10, 100, 1000, 10000, 100000]
  defaultMain
    [ bgroup "fences tests" 
      [ bench "Size 1 Test" $ whnf largestRectangle l1
      , bench "Size 10 Test" $ whnf largestRectangle l2
      , bench "Size 100 Test" $ whnf largestRectangle l3
      , bench "Size 1000 Test" $ whnf largestRectangle l4
      , bench "Size 10000 Test" $ whnf largestRectangle l5
      , bench "Size 100000 Test" $ whnf largestRectangle l6
      ]
    ]

-- Generate a list of a particular size
randomList :: Int -> IO FenceValues
randomList n = FenceValues <$> (sequence $ replicate n (randomRIO (1, 10000 :: Int)))

We’d normally run these benchmarks with stack bench (or cabal bench if you’re not using Stack). But we can also compile our code with the --profile flag. This will automatically create a profiling report with more information about our code. Note that profiling requires re-compiling ALL the dependencies with profiling enabled as well, so you don't want to switch back and forth a lot.

>> stack bench --profile
Benchmark fences-benchmarks: RUNNING...
benchmarking fences tests/Size 1 Test
time                 47.79 ns   (47.48 ns .. 48.10 ns)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 47.78 ns   (47.48 ns .. 48.24 ns)
std dev              1.163 ns   (817.2 ps .. 1.841 ns)
variance introduced by outliers: 37% (moderately inflated)

benchmarking fences tests/Size 10 Test
time                 3.324 μs   (3.297 μs .. 3.356 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 3.340 μs   (3.312 μs .. 3.368 μs)
std dev              98.52 ns   (79.65 ns .. 127.2 ns)
variance introduced by outliers: 38% (moderately inflated)

benchmarking fences tests/Size 100 Test
time                 107.3 μs   (106.3 μs .. 108.2 μs)
                     0.999 R²   (0.999 R² .. 0.999 R²)
mean                 107.2 μs   (106.3 μs .. 108.4 μs)
std dev              3.379 μs   (2.692 μs .. 4.667 μs)
variance introduced by outliers: 30% (moderately inflated)

benchmarking fences tests/Size 1000 Test
time                 8.724 ms   (8.596 ms .. 8.865 ms)
                     0.998 R²   (0.997 R² .. 0.999 R²)
mean                 8.638 ms   (8.560 ms .. 8.723 ms)
std dev              228.8 μs   (193.6 μs .. 272.8 μs)

benchmarking fences tests/Size 10000 Test
time                 909.2 ms   (899.3 ms .. 914.1 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 915.1 ms   (914.6 ms .. 915.8 ms)
std dev              620.1 μs   (136.0 as .. 664.8 μs)
variance introduced by outliers: 19% (moderately inflated)

benchmarking fences tests/Size 100000 Test
time                 103.9 s    (91.11 s .. 117.3 s)
                     0.997 R²   (0.997 R² .. 1.000 R²)
mean                 107.3 s    (103.7 s .. 109.4 s)
std dev              3.258 s    (0.0 s .. 3.702 s)
variance introduced by outliers: 19% (moderately inflated)

Benchmark fences-benchmarks: FINISH

So when we run this, we'll find something...troubling. It takes a looong time to run the final benchmark on size 100000. On average, this case takes over 100 seconds...more than a minute and a half! It's also worth noting how the average run time grows with the size of the case. Let's pare down the data a little bit:

Size 1: 47.78 ns
Size 10: 3.340 μs (increased ~70x)
Size 100: 107.2 μs (increased ~32x)
Size 1000: 8.638 ms (increased ~81x)
Size 10000: 915.1 ms (increased ~106x)
Size 100000: 107.3 s (increased ~117x)

Each time we increase the size of the problem by a factor of 10, the time spent increases by a factor closer to 100! This suggests our run time is O(n^2) (check out this guide if you are unfamiliar with Big-O notation). We'd like to do better.
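To make the "closer to 100" observation concrete, here is a small sketch that takes the mean times reported above and estimates the exponent k in a "time grows like n^k" model. Since each size is ten times the previous one, a ratio r between successive means corresponds to k = logBase 10 r. The numbers are copied from the benchmark output; the names are just for illustration.

```haskell
-- Mean run times in seconds, copied from the benchmark output above,
-- for input sizes 1, 10, 100, 1000, 10000 and 100000.
meanTimes :: [Double]
meanTimes = [47.78e-9, 3.340e-6, 107.2e-6, 8.638e-3, 915.1e-3, 107.3]

-- How much slower each size is than the previous one (each size is 10x larger).
growthRatios :: [Double]
growthRatios = zipWith (/) (drop 1 meanTimes) meanTimes

-- If time grows like n^k, a 10x larger input costs 10^k times as much,
-- so k = logBase 10 ratio.
estimatedExponents :: [Double]
estimatedExponents = map (logBase 10) growthRatios
```

For the two largest steps the estimate lands at roughly 2.0 to 2.1, which matches quadratic behavior. The smaller sizes give noisier estimates, most likely because constant overheads matter more when the input is tiny.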

Determining the Problem

So we want to figure out why our code isn't performing very well. Luckily, we already profiled our benchmark! Profiling writes its output to a file called fences-benchmark.prof, which has some very interesting results:

COST CENTRE                            MODULE SRC                 %time %alloc
minimumHeightIndexValue.valsInInterval Lib    src/Lib.hs:45:5-95   69.8   99.7
valueAtIndex                           Lib    src/Lib.hs:51:1-74   29.3    0.0

We see that we have two big culprits taking a lot of time. First, there is our function that determines the minimum within a specific interval. The report even pinpoints the offending part of that function: we spend a lot of time gathering the values for a given interval. In second place, we have valueAtIndex. This means we also spend a lot of time getting values out of our list.

First let’s be glad we’ve factored our code well. If we had written our entire solution in one big function, we wouldn’t have any leads here. This makes it much easier for us to analyze the problem. When examining the code, we see why both of these functions could produce O(n^2) behavior.

Due to the number of recursive calls we make, we’ll call each of these functions O(n) times. Then when we call valueAtIndex, we use the (!!) operator on our linked list. This takes O(n) time. Scanning the whole interval for the minimum height has the same effect. In the worst case, we have to look at every element in the list! I’m hand waving a bit here, but that is the basic result. When we call these O(n) pieces O(n) times, we get O(n^2) time total.
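A rough count makes the hand-waving a bit more concrete. Suppose (a simplifying assumption of ours, not something the profile tells us) that over the course of the recursion we look up each index 0, 1, ..., n-1 once with (!!), which has to walk past i list cells to reach index i. The lookups alone then cost:

```latex
\sum_{i=0}^{n-1} i \;=\; \frac{n(n-1)}{2} \;\in\; \Theta(n^2)
```

For n = 100000 that is already about 5 billion cell visits, before we even count the interval scans.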

Cliff Hanger Ending

We can actually solve this problem in O(n log n) time, a dramatic improvement over the current O(n^2). But we’ll have to improve our data structures to accomplish this. First, we’ll store our values so that we can go from the index to the element in sub-linear time. This is easy. Second, we have to determine the index containing the minimum element within an arbitrary interval. This is a bit trickier to do in sub-linear time. We'll need a more advanced data structure.
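For the first of those two fixes, one possible approach (our assumption, not necessarily what part 2 will use) is to copy the list into an immutable array from the array package that ships with GHC. Indexing then takes O(1) time instead of the O(n) walk that (!!) does. The helper names below are hypothetical and work on plain Int lists rather than FenceValues.

```haskell
import Data.Array (Array, bounds, listArray, (!))

-- Hypothetical helper: build a 0-indexed array from a list of heights.
-- This costs O(n) once, up front.
toHeightArray :: [Int] -> Array Int Int
toHeightArray xs = listArray (0, length xs - 1) xs

-- Constant-time lookup, unlike (!!) on a list, which walks i cells.
valueAtIndexArr :: Array Int Int -> Int -> Int
valueAtIndexArr arr i = arr ! i

-- Number of stored elements, derived from the array bounds.
arraySize :: Array Int Int -> Int
arraySize arr = let (lo, hi) = bounds arr in hi - lo + 1
```

This only addresses the lookup half of the problem; finding the minimum of an arbitrary interval in sub-linear time still needs the more advanced structure mentioned above.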

To find out how we solve these problems, you’ll have to wait for part 2 of this series! Come back next week to the Monday Morning Haskell blog!

As a reminder, we’ve just published a FREE mini-course on Stack. It’ll teach you the basics of laying out a project and running commands on it using the Stack tool. You should enroll in the Monday Morning Haskell Academy to sign up! This is our first course of any kind, so we would love to get some feedback! Once you know about Stack, it'll be a lot easier to try this problem out for yourself!

In addition to Stack, recursion also featured pretty heavily in our solution here. If you've done any amount of functional programming, you've seen recursion in action. But if you want to solidify your knowledge, you should download our FREE Recursion Workbook! It has two chapters' worth of content on recursion and 10 practice problems you can work through!

Never programmed in Haskell? No problem! You can download our Getting Started Checklist and it'll help you get Haskell installed. It'll also point you toward some great learning resources. Take a look!

James Bowen

Taking a Close look at Lenses

So when we first learned about creating our own data types, we went over the idea of record syntax. The idea is pretty simple: we want to create objects with named fields. This lets us avoid the tedium of pattern matching on objects every time we want to get data out of them. Record syntax gives us functions that pull an individual field out of an object. Besides retrieving fields, we can also create a modified copy of an object, specifying only the fields we want to change.

import Data.Time.Clock (UTCTime)

data Task = Task
  { taskName :: String
  , taskExpectedMinutes :: Int
  , taskCompleteTime :: UTCTime }

truncateName :: Task -> Task
truncateName task = task { taskName = take 15 originalName }
  where
    originalName = taskName task

We see examples of both these ideas in this little code snippet. Notice that this isn't at all like the syntax in a language like Java or JavaScript. In JavaScript, we might write a function with comparable functionality like this:

function truncateName(task) {
  task.taskName = task.taskName.substring(0,15);
  return task;
}

This is more in line with how most programmers think of getters and setters: we put the name of the field after the object instead of before it. Suppose we add another layer on top of our data model. We start to see ways in which record syntax can get a little bit clunky:

data Project = Project
  { projectName :: String
  , projectCurrentTask :: Task
  , projectRemainingTasks :: [Task] }

truncateCurrentTaskName :: Project -> Project
truncateCurrentTaskName project = project { projectCurrentTask = modifiedTask }
  where
    cTask = projectCurrentTask project
    modifiedTask = cTask { taskName = take 15 (taskName cTask) }

In this example, the JavaScript equivalent (project.projectCurrentTask.taskName = project.projectCurrentTask.taskName.substring(0, 15);) actually looks somewhat cleaner. Admittedly, it is performing the "simpler" operation of updating the object in place.

So what can we do in Haskell about this? Are we doomed to be using record syntax our whole lives and making odd updates to objects? Of course not! There's a great tool that allows us to get this more natural-looking syntax. It also enables some cool functionality in our code. The tools are lenses and prisms, which offer a different way to write getters and setters for our objects. There are a few different ways of doing lenses, but we'll focus on the lens library and its Control.Lens module.

Lenses

Lenses are functions that take an object and "peel" off layers from the object. They allow us to access deeper underlying fields. The syntax can be a little bit confusing, so it can be hard to write our own lenses at first. Luckily, the Control.Lens.TH module has us covered there. First, by convention, we'll change our data type so that all the field names begin with underscores:

data Task = Task
  { _taskName :: String
  , _taskExpectedMinutes :: Int
  , _taskCompleteTime :: UTCTime }

data Project = Project
  { _projectName :: String
  , _projectCurrentTask :: Task
  , _projectRemainingTasks :: [Task] }

Now we can use the Template Haskell function makeLenses, which generates the getter and setter functions our data types need:

{-# LANGUAGE TemplateHaskell #-}

import Control.Lens
import Data.Time.Clock (UTCTime)

data Task = Task
  { _taskName :: String
  , _taskExpectedMinutes :: Int
  , _taskCompleteTime :: UTCTime }

makeLenses ''Task

data Project = Project
  { _projectName :: String
  , _projectCurrentTask :: Task
  , _projectRemainingTasks :: [Task] }

makeLenses ''Project

If you didn't read last week's article on Data.Aeson, you might be thinking "Whoa whoa whoa stop. What is this Template Haskell nonsense?" Template Haskell is a system where we can have the compiler generate boilerplate code for us. It's useful in many situations, but there are tradeoffs.

The benefits are clear. Template Haskell allows us to avoid writing code that is very tedious and mindless to write. The drawbacks are a little more hidden. First, it has a tendency to increase our compile times a fair amount. Second, a lot of the functions we'll end up using won't be defined anywhere in our source code. This won't be too much of an issue here with our lenses, since the generated functions will be using the field names or some derivative of them. But it can still be frustrating for newer Haskell developers. Also, the type errors for lenses can be very confusing, which compounds the difficulty newcomers might have. Even seasoned Haskell developers are often confounded by them. So Template Haskell has a definite price to pay in terms of accessibility.

One thing to remember about lenses is that we don't have to use Template Haskell to generate them. In fact, here is how we could define these lens functions by hand. Don't sweat understanding the syntax; I'm just demonstrating that the amount of generated code isn't particularly big:

import Control.Lens (Lens', (<&>))
import Data.Time.Clock (UTCTime)

data Task = Task
  { _taskName :: String
  , _taskExpectedMinutes :: Int
  , _taskCompleteTime :: UTCTime }

taskName :: Lens' Task String
taskName func task@Task{ _taskName = name} =
  func name <&> \newName -> task {_taskName = newName }

taskExpectedMinutes :: Lens' Task Int
taskExpectedMinutes func task@Task{_taskExpectedMinutes = expMinutes} =
  func expMinutes <&> \newExpMinutes -> task {_taskExpectedMinutes = newExpMinutes}

taskCompleteTime :: Lens' Task UTCTime
taskCompleteTime func task@Task{_taskCompleteTime = completeTime} =
  func completeTime <&> \newCompleteTime -> task{_taskCompleteTime = newCompleteTime}

data Project = Project
  { _projectName :: String
  , _projectCurrentTask :: Task
  , _projectRemainingTasks :: [Task] }

projectName :: Lens' Project String
projectName func project@Project{ _projectName = name} =
  func name <&> \newName -> project {_projectName = newName }

projectCurrentTask :: Lens' Project Task
projectCurrentTask func project@Project{ _projectCurrentTask = cTask} =
  func cTask <&> \newTask -> project {_projectCurrentTask = newTask }

projectRemainingTasks :: Lens' Project [Task]
projectRemainingTasks func project@Project{ _projectRemainingTasks = tasks} =
  func tasks <&> \newTasks -> project {_projectRemainingTasks = newTasks }

Writing your own lenses can be tedious, but it also gives you more granular control over which lenses your type actually exports. For instance, you might not want to make particular fields public at all, or you might want them to be read-only. This is easier when writing your own lenses. One thing we can observe from this code is that we have a function for each of the fields in our object, and each of these functions encapsulates both the getter and the setter. Which role it plays depends on how we use the lens.
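To see how one function can be both the getter and the setter, here's a minimal, self-contained sketch that only needs base. It mirrors the shape of the hand-written lenses above, but the Lens' synonym and the view, over and set helpers are our own simplified stand-ins for the more general versions in Control.Lens. Choosing Const as the functor makes the lens read the field; choosing Identity makes it rebuild the record with a new value.

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

-- Simplified stand-in for Control.Lens's Lens' (same van Laarhoven shape).
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s

-- A cut-down Task (no UTCTime field, so only base is needed).
data Task = Task
  { _taskName :: String
  , _taskExpectedMinutes :: Int
  } deriving (Show, Eq)

-- Hand-written lens in the same style as the article's.
taskName :: Lens' Task String
taskName func task@Task{ _taskName = name } =
  fmap (\newName -> task { _taskName = newName }) (func name)

-- Getter: with Const, the rebuild step is ignored and the field is carried out.
view :: Lens' s a -> s -> a
view l s = getConst (l Const s)

-- Modifier: with Identity, the rebuild step runs on the transformed value.
over :: Lens' s a -> (a -> a) -> s -> s
over l f s = runIdentity (l (Identity . f) s)

-- Setter: a modification that ignores the old value.
set :: Lens' s a -> a -> s -> s
set l a = over l (const a)
```

The same taskName function serves all three helpers; only the functor the caller picks changes what it does.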

Operators

Haskell libraries can be notorious for their use of strange looking operators. Lens might be one of the biggest offenders here. But we’ll try to limit ourselves to a few of the most basic operators. These will give us a flavor for how lenses operate both in the getting and setting ways.

The first operator we’ll concern ourselves with is the “view” operator, (^.). This is a simple “get” operator that allows you to access the field of a particular object. So now we can re-write the very first code snippet to show this operator in action. It’s called the “view” operator since it is a synonym for the view function, which is how we can express it as a non-operator:

truncateName :: Task -> Task
truncateName task = task { _taskName = take 15 originalName }
  where
    originalName = task ^. taskName
    -- equivalent to: `view taskName task`

The next operator is the "set" operator, (.~). As you might expect, this returns an updated copy of the object with a new value in the field. Let's update the definition of our simple truncating function to use this:

truncateName' :: Task -> Task
truncateName' task = task & taskName .~ take 15 (task ^. taskName)

We can do even better by introducing the %~ operator (a synonym for the over function). It allows us to apply an arbitrary function to the field our lens focuses on. In this case, we want to keep the current value of the field, just with the take 15 function applied to it. We'll use this to complete our function definition.

truncateName'' :: Task -> Task
truncateName'' task = task & taskName %~ take 15

Note that & itself is the reverse function application operator: x & f is just f x. Because it binds more loosely than the lens operators, we can use it to chain several lens operations on the same item. For instance, here's an example where we change both the task name and the expected time:

changeTask :: Task -> Task
changeTask task = task
  & taskName .~ "Updated Task"
  & taskExpectedMinutes .~ 30

One thing to note about lenses is that they get more powerful the deeper you nest them. It is easy to compose lenses with function composition. For instance, remember how annoying it was to truncate the current task name of a project? Well that’s a lot easier with lenses!

truncateCurrentTaskName :: Project -> Project
truncateCurrentTaskName project = project
  & projectCurrentTask . taskName %~ take 15

In this case, we access the task's name with projectCurrentTask . taskName. This almost looks like JavaScript syntax! It allows us to dive in and change the task's name without much fuss!
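Composition works because a lens is just a function, so (.) chains lenses from the outermost field inward. The sketch below is self-contained and base-only: the Lens' synonym, the two lenses and our own (%~) (given the same infixr 4 fixity it has in Control.Lens) are simplified stand-ins, while (&) is the real operator from Data.Function.

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Function ((&))
import Data.Functor.Identity (Identity (..))

-- Simplified stand-in for Control.Lens's Lens'.
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s

data Task = Task { _taskName :: String } deriving (Show, Eq)

data Project = Project
  { _projectName :: String
  , _projectCurrentTask :: Task
  } deriving (Show, Eq)

taskName :: Lens' Task String
taskName func task = fmap (\n -> task { _taskName = n }) (func (_taskName task))

projectCurrentTask :: Lens' Project Task
projectCurrentTask func project =
  fmap (\t -> project { _projectCurrentTask = t }) (func (_projectCurrentTask project))

-- Our own %~: modify whatever the lens focuses on.
infixr 4 %~
(%~) :: Lens' s a -> (a -> a) -> s -> s
l %~ f = runIdentity . l (Identity . f)

-- projectCurrentTask . taskName focuses on the name inside the current task.
truncateCurrentTaskName :: Project -> Project
truncateCurrentTaskName project = project
  & projectCurrentTask . taskName %~ take 15
```

Since (.) is infixr 9, %~ is infixr 4 and & is infixl 1, the last line parses as project & ((projectCurrentTask . taskName) %~ take 15), with no parentheses needed.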

Prisms

Now that we understand the basics of lenses, we can move one level deeper and look at prisms. Lenses allowed us to peek into the different fields within a product type. But prisms allow us to look at the different branches of a sum type. I don’t use this terminology too much on this blog so here's a quick example explaining sum vs. product types:

-- "Product" type...one constructor, many fields
data OriginalTask = OriginalTask
  { taskName :: String
  , taskExpectedMinutes :: Int
  , taskCompleteTime :: UTCTime }

-- "Sum" type...many constructors
data NewTask =
  SimpleTask String |
  HarderTask String Int |
  CompoundTask String [NewTask]

So the top example is a "product" type, with several fields and one constructor. Since we have named the fields using record syntax, we refer to it as a "distinguished" product. The bottom type is a "sum" type, since it has several constructors. We can generate prisms on such a type in a similar way to how we generate lenses for our Task type. We'll use the makePrisms function instead of makeLenses:

data NewTask =
  SimpleTask String |
  HarderTask String Int |
  CompoundTask String [NewTask]

makePrisms ''NewTask

Notice the difference in convention between lenses and prisms. With lenses, we give the field names underscores and then the lens names have no underscores. With prisms, the constructors look “clean” and the prism names have underscores.

Fundamentally, a prism explores one possible branch of an object. Since the object might have been built with a different constructor, looking through a prism can fail, so we get Maybe values back. We access these values with the ^? operator. When it succeeds, it strips off the constructor and extracts the values inside it, turning the constructor's fields into an "undistinguished" product. This means if there is only one field we get that field, and if there are several fields, we get a tuple.

>> let a = SimpleTask "Clean"
>> let b = HarderTask "Clean Kitchen" 15
>> let c = CompoundTask "Clean House" [a,b]
>> a ^? _SimpleTask
Just "Clean"
>> b ^? _HarderTask
Just ("Clean Kitchen",15)
>> a ^? _HarderTask
Nothing
>> c ^? _SimpleTask
Nothing

This behavior doesn’t change if we use record syntax on each of our different types. Since we get a tuple whenever we have two or more fields, we can actually use lenses to delve further into that tuple. This offers some cool composability. Note that _1 is a lens that allows us to access the first element of a tuple. Similarly with _2 and _3 and so on.

>> b ^? _HarderTask . _2
Just 15
>> c ^? _CompoundTask . _1
Just "Clean House"
>> a ^? _HarderTask . _1
Nothing

The best part is that we can still have "setters" over prisms, and these setters don't even have error conditions! If we try setting something through the "wrong" branch, we simply get the original object back out:

>> let b = HarderTask "Clean Kitchen" 15
>> b & _SimpleTask .~ "Clean Garage"
HarderTask "Clean Kitchen" 15
>> b & _HarderTask . _2 .~ 30
HarderTask "Clean Kitchen" 30
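To demystify both the Maybe on the way in and the no-op on the wrong branch, here's a deliberately simplified model of a prism as a pair of functions: one that builds the sum type from a branch's fields, and one that tries to match that branch. This is purely our own illustration; Control.Lens encodes prisms very differently, and the names below (SimplePrism, buildP, matchP, previewP, overP, simpleTaskP, harderTaskP) are hypothetical.

```haskell
data NewTask
  = SimpleTask String
  | HarderTask String Int
  | CompoundTask String [NewTask]
  deriving (Show, Eq)

-- Hypothetical, simplified prism: NOT how Control.Lens represents prisms.
data SimplePrism s a = SimplePrism
  { buildP :: a -> s        -- rebuild the value from one branch's fields
  , matchP :: s -> Maybe a  -- succeed only if the value is that branch
  }

-- Stand-in for _SimpleTask: a single field comes back as-is.
simpleTaskP :: SimplePrism NewTask String
simpleTaskP = SimplePrism SimpleTask match
  where
    match (SimpleTask name) = Just name
    match _ = Nothing

-- Stand-in for _HarderTask: several fields come back as a tuple.
harderTaskP :: SimplePrism NewTask (String, Int)
harderTaskP = SimplePrism (uncurry HarderTask) match
  where
    match (HarderTask name mins) = Just (name, mins)
    match _ = Nothing

-- Roughly what ^? does: try the branch and maybe get its fields.
previewP :: SimplePrism s a -> s -> Maybe a
previewP = matchP

-- Roughly what setting through a prism does: modify on a match,
-- otherwise hand back the original value unchanged.
overP :: SimplePrism s a -> (a -> a) -> s -> s
overP p f s = maybe s (buildP p . f) (matchP p s)
```

The "no error conditions" property falls out of overP: when matchP returns Nothing, there is simply nothing to rebuild, so the input comes back untouched.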

Folds and Traversals Primer

The last example I'll leave you with is a quick taste of some of the more advanced things we can do with prisms. We'll take a peek at the concept of folds and traversals. A prism addresses one part of a structure that may or may not exist. Traversals and folds, by contrast, address zero or more parts of a structure at once.

Suppose we have a list of our NewTask items. We don’t care about the compound tasks or the basic tasks. We just want to know the total amount of time on our HarderTask items. We could define such a function like this that performs a manual pattern match:

sumHarderTasks :: [NewTask] -> Int
sumHarderTasks = foldl timeFromTask 0
  where
    timeFromTask accum (HarderTask _ time) = time + accum
    timeFromTask accum _ = accum

{- In GHCI:

>> let tasks = [SimpleTask "first", HarderTask "second" 10, SimpleTask "third", HarderTask "fourth" 15, CompoundTask "fifth" [SimpleTask "last"]]
>> sumHarderTasks tasks
25

-}

But we can do it a lot more easily with our prisms. Here, the traverse function acts as a traversal over every element of our list. This function is so powerful it deserves its own article.

In this example, we’ll “traverse” our list of tasks. We'll pick out all the ones with HarderTask using a prism. Then we'll sum the values we get by applying the _2 lens. Awesome!

sumHarderTasks :: [NewTask] -> Int
sumHarderTasks tasks = sum (tasks ^.. traverse . _HarderTask . _2)

So if we break it down, tasks ^.. traverse gives us back the original list. Adding the _HarderTask prism filters it, leaving only the (name, time) pairs from tasks built with the HarderTask constructor. Applying the _2 lens then turns this filtered list into a list of the times on those tasks. Finally, we take the sum of this list.
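If you want to cross-check the lens version without any library, a list comprehension with a pattern performs the same filtering: elements that don't match HarderTask are silently skipped, much like the prism skipping non-matching branches. The sketch below is our own comparison; it also switches the manual fold to the strict foldl', which avoids building up a chain of unevaluated additions in the accumulator.

```haskell
import Data.List (foldl')

data NewTask
  = SimpleTask String
  | HarderTask String Int
  | CompoundTask String [NewTask]
  deriving (Show, Eq)

-- A failed pattern match in a list comprehension just skips the element,
-- playing the role of traverse . _HarderTask . _2.
sumHarderTasksLC :: [NewTask] -> Int
sumHarderTasksLC tasks = sum [time | HarderTask _ time <- tasks]

-- The manual fold from earlier, using the strict foldl'.
sumHarderTasksFold :: [NewTask] -> Int
sumHarderTasksFold = foldl' timeFromTask 0
  where
    timeFromTask accum (HarderTask _ time) = time + accum
    timeFromTask accum _ = accum

exampleTasks :: [NewTask]
exampleTasks =
  [ SimpleTask "first", HarderTask "second" 10, SimpleTask "third"
  , HarderTask "fourth" 15, CompoundTask "fifth" [SimpleTask "last"] ]
```

Like the lens version, neither function looks inside a CompoundTask's subtasks; only the top-level HarderTask items are counted.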

Conclusion

So in this article, we saw a basic introduction to the idea of lenses and prisms. We saw how to generate these functions over our types with Template Haskell. We got a brief feel for some of the operators involved. We also saw how these concepts make it a lot easier for us to deal with nested structures. If you have time for a more thorough introduction to lenses, you should watch John Wiegley's talk from BayHac! It was my primary inspiration for this article, and it helped me immensely in understanding the ideas I presented here. In particular, if you want more ideas about traversals and folds, he has some super cool examples.

If you’re new to Haskell, don’t sweat all this advanced stuff! Check out our Getting Started Checklist. It'll teach you about some tools and resources to get yourself started on coding in Haskell!

Perhaps you’ve done a teensy bit of Haskell but want more practice on the fundamentals before you dive into something like lenses. If so, you should download our Recursion Workbook. It’ll walk you through the basics of recursion and higher order functions. You'll also get some practice problems to test your skills!
