Running Training Iterations


In our last article we built a simple Tensor Flow model to perform Q-Learning on our brain. This week, we'll build out the rest of the code we need to run iterations on this model. This will train it to perform better and make more intelligent decisions.

The machine learning code for this project is in a separate repository from the game code. Check out MazeLearner to follow along. Everything for this article is on the basic-trainer branch.

To learn more about Haskell and AI, make sure to read our Haskell and AI Series!

Iterating on the Model

First let's recall what our Tensor Flow model looks like:

data Model = Model
  { weightsT :: Variable Float
  , iterateWorldStep :: TensorData Float -> Session (Vector Float)
  , trainStep :: TensorData Float -> TensorData Float -> Session ()
  }

We need to think about how we'll use its last two functions. We want to iterate on the model and update its weights. Across the different iterations, there's certain information we need to track. The first value we'll track is the list of "rewards" from each iteration (this will become clearer in the next section). We'll also track the number of wins we get across the iterations.

To track these, we'll use the State monad, run on top of the Session.

runAllIterations :: Model -> World -> StateT ([Float], Int) Session ()

We'll also want a function to run a single iteration. This, in turn, will have its own state information. It will track the World state of the game it's playing. It will also track the sum of the accumulated reward values from the moves in that game. Since we'll run it from our function above, it will have a nested StateT type. It will ultimately return a boolean value indicating whether we won the game. We'll define the details in the next section:

runWorldIteration :: Model ->
  StateT (World, Float) (StateT ([Float], Int) Session) Bool
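
To make the nesting concrete, here is a minimal sketch of the same two-layer shape. It is only an illustration: IO stands in for Session, the "game" is a toy counter rather than a World, and every name in it (Base, Outer, Inner, playToyGame, recordToyGame, runToyGames) is hypothetical. It shows that a plain get reads the innermost state, one lift reaches the outer bookkeeping layer, and two lifts reach the base monad.

```haskell
import Control.Monad (replicateM_)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
  (StateT, execStateT, get, modify, put, runStateT)

-- Hypothetical stand-in for the article's Session monad.
type Base = IO

-- Outer bookkeeping shared by all games: (rewards so far, win count).
type Outer = StateT ([Float], Int) Base

-- Per-game state layered on top: (step counter, running reward).
type Inner = StateT (Int, Float) Outer

-- One toy "game": three steps worth 0.5 each, then report the result.
playToyGame :: Inner Bool
playToyGame = do
  (step, total) <- get            -- innermost (per-game) state
  if step >= 3
    then do
      (_, wins) <- lift get       -- one lift: the outer bookkeeping layer
      -- two lifts: the base monad (Session in the article, IO here)
      lift $ lift $ putStrLn ("game over; earlier wins: " ++ show wins)
      return (total > 1.0)
    else do
      put (step + 1, total + 0.5)
      playToyGame

-- Run one game from the outer layer and record its result.
recordToyGame :: Outer ()
recordToyGame = do
  (won, (_, total)) <- runStateT playToyGame (0, 0.0)
  modify $ \(rewards, wins) ->
    (total : rewards, if won then wins + 1 else wins)

-- Run n games and return the accumulated bookkeeping.
runToyGames :: Int -> Base ([Float], Int)
runToyGames n = execStateT (replicateM_ n recordToyGame) ([], 0)
```

Calling runToyGames 4 from GHCi runs four toy games and returns the rewards list alongside the win count. The real functions have the same shape, with Session at the bottom and World in the per-game state.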

We can now start by filling out our function for running all the iterations. Supposing we'll perform 1000 iterations, we'll make a loop for each iteration. We can start each loop by running the world iteration function on the current model.

runAllIterations :: Model -> World -> StateT ([Float], Int) Session ()
runAllIterations model initialWorld = do
  let numIterations = 1000
  void $ forM [1..numIterations] $ \i -> do
    (wonGame, (_, finalReward)) <-
      runStateT (runWorldIteration model) (initialWorld, 0.0)
    ...

And now the rest is a simple matter of using the results to update the existing state:

runAllIterations :: Model -> World -> StateT ([Float], Int) Session ()
runAllIterations model initialWorld = do
  let numIterations = 1000
  void $ forM [1..numIterations] $ \i -> do
    (wonGame, (_, finalReward)) <-
      runStateT (runWorldIteration model) (initialWorld, 0.0)
    (prevRewards, prevWinCount) <- get
    let newRewards = finalReward : prevRewards
    let newWinCount = if wonGame
          then prevWinCount + 1
          else prevWinCount
    put (newRewards, newWinCount)
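
The update itself does not depend on TensorFlow, so it can be factored into a pure function and checked on its own. The sketch below is one way to do that. The names TrainingStats, recordResult, and recordAll are hypothetical helpers and are not part of the MazeLearner code.

```haskell
import Data.List (foldl')

-- Bookkeeping carried across iterations: (rewards, newest first; win count).
type TrainingStats = ([Float], Int)

-- Hypothetical pure helper: fold one game's outcome into the stats.
recordResult :: TrainingStats -> (Bool, Float) -> TrainingStats
recordResult (prevRewards, prevWinCount) (wonGame, finalReward) =
  ( finalReward : prevRewards
  , if wonGame then prevWinCount + 1 else prevWinCount
  )

-- Replay a list of (won?, final reward) outcomes from an empty start state.
recordAll :: [(Bool, Float)] -> TrainingStats
recordAll = foldl' recordResult ([], 0)
```

Because each new reward is consed onto the front, the rewards list ends up in reverse chronological order. That is worth remembering when reading the final state.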

Running a Single Iteration

Now let's delve into the process of a single iteration. Broadly speaking, we have four goals.

  1. Take the current world and serialize it. Pass it through iterateWorldStep to get the move our model would make in this world.
  2. Apply this move, getting the "next" world state.
  3. Determine the scores for our moves in this next world. Apply the given reward as the score for the best of these moves.
  4. Use this result to compare against our original moves. Feed it into the training step and update our weights.

Let's start with steps 1 and 2. We'll get the vector representation of the current world. Then we need to encode it as TensorData so we can pass it to an input feed. Next we run our model's iterate step and get our output move. Then we can use that to advance the world state using stepWorld and updateEnvironment.

runWorldIteration
  :: Model
  -> StateT (World, Float) (StateT ([Float], Int) Session) Bool
runWorldIteration model = do
  -- Serialize the world
  (prevWorld :: World, prevReward) <- get
  let (inputWorldVector :: TensorData Float) =
        encodeTensorData (Shape [1, 8]) (vectorizeWorld prevWorld)

  -- Run our model to get the output vector and de-serialize it
  -- Lift twice to get into the Session monad
  (currentMove :: Vector Float) <- lift $ lift $
    (iterateWorldStep model) inputWorldVector
  let newMove = moveFromOutput currentMove

  -- Get the next world state
  let nextWorld = updateEnvironment (stepWorld newMove prevWorld)

Now we need to perform the Q-Learning step. We'll start by repeating the process in our new world state and getting the next vector of move scores:

runWorldIteration model = do
  let nextWorld = updateEnvironment (stepWorld newMove prevWorld)

  let nextWorldVector =
        encodeTensorData (Shape [1, 8]) (vectorizeWorld nextWorld)

  (nextMoveVector :: Vector Float) <- lift $ lift $
    (iterateWorldStep model) nextWorldVector

Now it gets a little tricky. We want to examine if the game is over after our last move. If we won, we'll get a reward of 1.0. If we lost, we'll get a reward of -1.0. Otherwise, there's no reward. While we figure out this reward value, we can also determine our final monadic action. We either return a boolean value if the game is over, or recursively iterate again:

runWorldIteration model = do
  let nextWorld = ...
  (nextMoveVector :: Vector Float) <- ...
  let (newReward, continuationAction) = case worldResult nextWorld of
        GameInProgress -> (0.0, runWorldIteration model)
        GameWon -> (1.0, return True)
        GameLost -> (-1.0, return False)

Now we'll look at the vector for our next move and replace one of its values. We'll find the maximum score, and replace it with a value that factors in the actual reward we get from the game: the reward plus 0.99 times that maximum score. This is how we insert "truth" into our training process and how we'll actually learn good reward values.

import qualified Data.Vector as V

runWorldIteration model = do
  let nextWorld = ...
  (nextMoveVector :: Vector Float) <- ...
  let (newReward, continuationAction) = ...
  let (bestNextMoveIndex, maxScore) =
        (V.maxIndex nextMoveVector, V.maximum nextMoveVector)
  let (targetActionValues :: Vector Float) = nextMoveVector V.//
        [(bestNextMoveIndex, newReward + (0.99 * maxScore))]
  let targetActionData =
        encodeTensorData (Shape [10, 1]) targetActionValues

This encoded vector is the second input to our training step. We'll still use the nextWorldVector as the first input. We conclude by updating our state variables to have their new values. Then we run the continuation action we got earlier.

runWorldIteration model = do
  let nextWorld = ...
  (nextMoveVector :: Vector Float) <- ...
  let (newReward, continuationAction) = ...
  let targetActionData = ...

  -- Run training to alter the weights
  lift $ lift $ (trainStep model) nextWorldVector targetActionData
  put (nextWorld, prevReward + newReward)
  continuationAction
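
The reward and target calculation above can be tried out without a Session. The following is a rough, dependency-free sketch that uses plain lists instead of Data.Vector. The GameResult type here is a stand-in for the game's own result type, and rewardFor, discount, and targetScores are hypothetical names introduced only for this illustration.

```haskell
-- Hypothetical stand-in for the game's result type.
data GameResult = GameInProgress | GameWon | GameLost
  deriving (Eq, Show)

-- Reward scheme from the article: +1 for a win, -1 for a loss, 0 otherwise.
rewardFor :: GameResult -> Float
rewardFor GameInProgress = 0.0
rewardFor GameWon        = 1.0
rewardFor GameLost       = -1.0

-- Discount factor applied to the best predicted score (0.99 in the article).
discount :: Float
discount = 0.99

-- Replace the highest score with reward + discount * that score.
-- An empty list is returned unchanged; in this sketch, ties resolve
-- to the first maximum.
targetScores :: Float -> [Float] -> [Float]
targetScores _ [] = []
targetScores reward scores =
  [ if i == bestIndex then reward + discount * maxScore else s
  | (i, s) <- zip [0 :: Int ..] scores
  ]
  where
    maxScore  = maximum scores
    bestIndex = length (takeWhile (/= maxScore) scores)
```

For example, targetScores (rewardFor GameWon) [2.0, 1.0] changes only the first entry and leaves every other score as the model predicted it. Training therefore nudges just the score of the best move toward the observed reward.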

Tying It Together

Now to make this code run, we need a little bit of code to tie it together. We'll make a Session action to train our game. It will output the final weights of our model.

trainGame :: World -> Session (Vector Float)
trainGame w = do
  model <- buildModel
  (finalReward, finalWinCount) <-
    execStateT (runAllIterations model w) ([], 0)
  run (readValue $ weightsT model)

Then we can run this from IO using runSession.

playGameTraining :: World -> IO (Vector Float)
playGameTraining w = runSession (trainGame w)

Last of all, we can run this on any World we like by first loading it from a file. For our first examples, we'll use a smaller 10x10 grid with 2 enemies and 1 drill powerup.

main :: IO ()
main = do
  world <- loadWorldFromFile "training_games/"
  finalWeights <- playGameTraining world
  print finalWeights


We've now got the basics down for making our Tensor Flow program work. Come back next week, when we'll take a more careful look at how it's performing. We'll see if the AI from this process is actually any good or if there are tweaks we need to make to the learning process.

And make sure to download our Haskell Tensor Flow Guide! This library is difficult to use. There are a lot of secondary dependencies for it. So don't go in trying to use it blind!