Hugo's Blog

Demystifying the state monad

Introduction

Haskell: for some, the digital equivalent of a 'garden of Eden'-esque waterfall under a beautiful sunset; for others, a source of pain and stress. For beginners, this pain and stress is often at its worst once they have learned a little bit about monads like IO. They feel rather confident working with those, perhaps even with the Maybe monad. But then the State monad crosses their path, and they lose hope. In this post I will rediscover and implement the state monad in the hope of conveying its elegance and surprising simplicity. A small spoiler: a state monad can't do anything that plain functions can't, it just does it more elegantly.

A basic example: the counter

Consider the following bit of simple Haskell code:

getNextNr :: Int -> (Int, String)
getNextNr i = (i+1, show i)

main :: IO ()
main = let 
    initial :: Int
    initial = 0

    functions :: [Int -> (Int, String)]
    functions = replicate 5 getNextNr

    chain :: [Int -> (Int, String)] -> Int -> (Int, [String])
    chain [] = error "No empty lists plz"
    chain [f] = \i -> let (ni, v) = f i in (ni, [v])
    chain (f:fs) = \i -> let (ni, v) = f i
                             (fi, vs) = chain fs ni
                             in (fi, v:vs)

    (_, numbers) = chain functions initial
    in print numbers

This program figures out a way to work with the getNextNr function, which not only converts the number you give it to a string but also increments that number. We want to execute this function a number of times and collect the strings it outputs, while chaining the resulting integer back into the input of the next call. Unsurprisingly, the output of this piece of code is:

["0","1","2","3","4"]

I hope you agree that this code snippet is neither particularly elegant nor easy to read.

Our goal: abstracting state

The previous code snippet is not very elegant mostly because we constantly have to deal with the fact that getNextNr returns more than one thing. After all, that is the only way it can have more than one effect, since Haskell is a pure language. However, that does not mean that we cannot abstract the action of keeping track of this counter. To shift the wording, this counter can be seen as a form of state.

Abstractly, we can see the pattern:

a -> (a, b)

where a represents the type of the state, and b the type of the resulting value of our entire function. From this, we can create a datatype to begin our abstraction journey:

newtype State a b = State (a -> (a, b))

Keep in mind that a newtype behaves much like a datatype with exactly one constructor and one field, but the wrapper has no runtime cost, which allows the compiler to be a bit smarter and faster.

If you are one of the frustrated Haskell programmers that this post is aimed at, then you are probably asking: "alright, but how do I take the b out of State a b?" Good question, let's create a function for that:

runState :: State a b -> a -> (a, b)
runState (State f) = f

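Given a state function and an initial state, this function returns the final state and the resulting b. If we don't care about the final state, we could spoil the users of our code with a special version. (Be aware that the standard library's Control.Monad.State calls this function evalState; its execState returns the final state instead, and its tuples put the result first. This post sticks with its own names and ordering.)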

execState :: State a b -> a -> b
execState (State f) = snd . f
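To make this concrete, here is a small sketch of running a hand-written State value. The names tick and demo are made up for this illustration; everything else is copied from the definitions above.

```haskell
-- Same definitions as in the post: the state comes first in the tuple.
newtype State a b = State (a -> (a, b))

runState :: State a b -> a -> (a, b)
runState (State f) = f

-- Named as in this post: returns only the result value.
execState :: State a b -> a -> b
execState (State f) = snd . f

-- Hypothetical example: the counter written directly as a wrapped function.
tick :: State Int String
tick = State $ \i -> (i + 1, show i)

-- Run it once with both runners, starting from 41.
demo :: ((Int, String), String)
demo = (runState tick 41, execState tick 41)
-- demo == ((42, "41"), "41")
```

Nothing magical happens here: runState only unwraps the function and applies it.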

This is all fine and dandy, but we haven't really created anything interesting yet. Let's jump ahead in time and imagine what our getNextNr function might look like using the state monad:

getNextNr :: State Int String
getNextNr = do
    -- Get the current state (our counter)
    current <- get
    -- Update the state 
    put (current + 1)
    -- Produce the resulting value
    return (show current)

Unfortunately, this does not yet compile because our custom State type does not implement the Monad typeclass yet (which first requires Functor and Applicative). The functions get and put also do not exist yet.

Making State a functor

A functor is a type that implements the Functor typeclass. This typeclass requires only the function fmap. As explained in C# can be functional pt1: The Maybe Monad, fmap just allows users to apply a function to the value contained in the functor. In our case this is the resulting value of type b:

instance Functor (State a) where
    fmap f state = State $ \s -> let (ns, res) = runState state s in (ns, f res)

As you can see, we simply 'hijack' the final part of the state function by applying the transformation f to the resulting value res. The resulting state ns is untouched. We can now make a new version of getNextNr that is very exciting:

jacobDubbel :: State Int (String, String)
jacobDubbel = fmap (\x -> (x,x)) getNextNr

The function duplicates the resulting value into a tuple without having to unpack the State object.
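Here is a sketch of what running it looks like. Since do notation is not available at this point, I assume getNextNr is written as a raw wrapped function; functorDemo is a name made up for this example.

```haskell
newtype State a b = State (a -> (a, b))

runState :: State a b -> a -> (a, b)
runState (State f) = f

instance Functor (State a) where
    fmap f state = State $ \s -> let (ns, res) = runState state s in (ns, f res)

-- Assumption: the counter as a plain wrapped function (no Monad instance yet).
getNextNr :: State Int String
getNextNr = State $ \i -> (i + 1, show i)

jacobDubbel :: State Int (String, String)
jacobDubbel = fmap (\x -> (x, x)) getNextNr

-- The counter still advances by one; only the result is duplicated.
functorDemo :: (Int, (String, String))
functorDemo = runState jacobDubbel 7
-- functorDemo == (8, ("7", "7"))
```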

Making State an applicative

If we want our State type to be used for constructing other datatypes, we need to implement Applicative:

instance Applicative (State a) where
    sf <*> sarg = State $ \s -> let 
        -- extract the function contained in sf
        (ns, f) = runState sf s
        -- extract the argument contained in sarg (using the new state ns)
        (nns, arg) = runState sarg ns 
        -- apply the function to the argument and return the final state
        in (nns, f arg)

    -- pure is the same thing as return:
    -- we just set x as the resulting value without touching the state
    pure x = State $ \s -> (s, x)

This new instance allows us to do the following:

data ThreeStrings = ThreeStrings String String String

tr :: State Int ThreeStrings
tr = ThreeStrings <$> getNextNr <*> getNextNr <*> getNextNr

This really shows how we can now compose multiple calls to a State function to build bigger things without having to worry about passing on the counter every time.
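To check that the counter really is threaded from left to right, here is a sketch that runs tr from 0. I assume getNextNr is the raw wrapped function (the Monad instance does not exist yet), and I add deriving (Eq, Show) to ThreeStrings purely so the result can be inspected.

```haskell
newtype State a b = State (a -> (a, b))

runState :: State a b -> a -> (a, b)
runState (State f) = f

instance Functor (State a) where
    fmap f state = State $ \s -> let (ns, res) = runState state s in (ns, f res)

instance Applicative (State a) where
    pure x = State $ \s -> (s, x)
    sf <*> sarg = State $ \s ->
        let (ns, f)    = runState sf s      -- run the function side first
            (nns, arg) = runState sarg ns   -- then the argument, with the new state
        in (nns, f arg)

-- Assumption: the counter as a plain wrapped function.
getNextNr :: State Int String
getNextNr = State $ \i -> (i + 1, show i)

-- deriving added only for this illustration.
data ThreeStrings = ThreeStrings String String String
    deriving (Eq, Show)

tr :: State Int ThreeStrings
tr = ThreeStrings <$> getNextNr <*> getNextNr <*> getNextNr

applicativeDemo :: (Int, ThreeStrings)
applicativeDemo = runState tr 0
-- applicativeDemo == (3, ThreeStrings "0" "1" "2")
```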

Making State a monad

The most powerful of them all. Making State a monad will give us complete freedom in how we chain multiple operations, as well as unlock that sweet sweet do notation that we still need for getNextNr to work.

instance Monad (State a) where
    -- Run the State on the lhs, feed its result x to f to get the next State, and run that one with the updated state'
    (State sfunc) >>= f = State $ \state -> let (state', x) = sfunc state in runState (f x) state'

    -- return is the same as pure (in modern GHC this is already the default, so this line is optional)
    return = pure
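As a sketch of what bind buys us, the example below chains two counter steps and lets the second part see the result of the first. The names tick, twoTicks and bindDemo are made up for this illustration; the instances mirror the ones from this post.

```haskell
newtype State a b = State (a -> (a, b))

runState :: State a b -> a -> (a, b)
runState (State f) = f

instance Functor (State a) where
    fmap f st = State $ \s -> let (ns, res) = runState st s in (ns, f res)

instance Applicative (State a) where
    pure x = State $ \s -> (s, x)
    sf <*> sarg = State $ \s ->
        let (ns, f)    = runState sf s
            (nns, arg) = runState sarg ns
        in (nns, f arg)

instance Monad (State a) where
    (State sfunc) >>= f = State $ \s ->
        let (s', x) = sfunc s in runState (f x) s'

-- Hypothetical counter step, written as a raw wrapped function.
tick :: State Int String
tick = State $ \i -> (i + 1, show i)

-- The lambdas receive the results; the counter is passed along behind the scenes.
twoTicks :: State Int String
twoTicks = tick >>= \a -> tick >>= \b -> pure (a ++ "," ++ b)

bindDemo :: (Int, String)
bindDemo = runState twoTicks 0
-- bindDemo == (2, "0,1")
```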

Wrapping up, get and put

The last two missing pieces in the puzzle are the get and put functions. The get function should give us the current state as is. We can do this by duplicating the state into the result value. This means that the state itself remains unchanged, but we can now get hold of it via the bind operation.

get :: State a a
get = State $ \state -> (state, state)

Lastly, the put operation is similar, except that the effect is exactly reversed: we don't care about the result value, we just want to update the state regardless of its old value.

put :: a -> State a ()
put newstate = State $ \_ -> (newstate, ())
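With get and put in place, the do block from earlier compiles. The sketch below puts it next to a hand-desugared version to show that do notation is only sugar for >>= and >>. The name getNextNrDesugared is made up for this comparison.

```haskell
newtype State a b = State (a -> (a, b))

runState :: State a b -> a -> (a, b)
runState (State f) = f

instance Functor (State a) where
    fmap f st = State $ \s -> let (ns, res) = runState st s in (ns, f res)

instance Applicative (State a) where
    pure x = State $ \s -> (s, x)
    sf <*> sarg = State $ \s ->
        let (ns, f)    = runState sf s
            (nns, arg) = runState sarg ns
        in (nns, f arg)

instance Monad (State a) where
    (State sfunc) >>= f = State $ \s ->
        let (s', x) = sfunc s in runState (f x) s'

get :: State a a
get = State $ \s -> (s, s)

put :: a -> State a ()
put newstate = State $ \_ -> (newstate, ())

-- The version from the post, using do notation.
getNextNr :: State Int String
getNextNr = do
    current <- get
    put (current + 1)
    return (show current)

-- Hypothetical name: the same thing written with explicit binds.
getNextNrDesugared :: State Int String
getNextNrDesugared =
    get >>= \current -> put (current + 1) >> return (show current)
```

Both versions give (6, "5") when run with runState from an initial state of 5.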

With all these changes in place, and our updated version of getNextNr we can recreate the initial program way more elegantly:

main :: IO ()
main = let
    initial :: Int
    initial = 0

    functions :: [State Int String]
    functions = replicate 5 getNextNr

    chained :: State Int [String]
    chained = sequence functions

    numbers = execState chained initial
    in print numbers

For a complete picture, let me reinvent sequence as well. This function normally operates on any Traversable, but let's keep it simple and just do lists for now:

listSequence :: Monad m => [m a] -> m [a]
listSequence [] = return []
listSequence (x:xs) = (:) <$> x <*> listSequence xs
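Because listSequence only asks for a Monad, it is not tied to State at all. As a quick sketch, here it is applied to Maybe values; allPresent and oneMissing are names made up for this example.

```haskell
listSequence :: Monad m => [m a] -> m [a]
listSequence [] = return []
listSequence (x:xs) = (:) <$> x <*> listSequence xs

-- Every element is a Just, so the values are collected in order.
allPresent :: Maybe [Int]
allPresent = listSequence [Just 1, Just 2, Just 3]
-- allPresent == Just [1, 2, 3]

-- A single Nothing makes the whole result Nothing.
oneMissing :: Maybe [Int]
oneMissing = listSequence [Just 1, Nothing, Just 3]
-- oneMissing == Nothing
```

Note that the body only uses <$> and <*>, so an Applicative constraint would be enough here.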

A different example: managing game state with RNG

I have to admit that the example up until now may seem a bit contrived. I did this to keep it as simple as possible. However, I owe you a slightly more useful example.

Consider that we are making a game. We have a GameState type and the main update loop:

data GameState = GameState { seed :: Int, score :: Int, playerX :: Int, playerY :: Int }

-- currently a very boring game
update :: GameState -> GameState
update = id

Suppose we then want a function that generates a random number for us using a magical hash function (I'm like Haskell, lazy):

hash :: Int -> Int
hash = undefined

randomIntYuck :: GameState -> (GameState, Int)
randomIntYuck gs@GameState { seed = seed } = let ns = hash seed in (gs {seed=ns}, ns)

Hopefully the type signature rings a bell by now! We can use our state monad!

randomInt :: State GameState Int
randomInt = get >>= \gs@GameState {seed=seed} -> let ns = hash seed in put (gs {seed=ns}) >> pure ns

randomInt' :: State GameState Int
randomInt' = do
    gs@GameState {seed=seed} <- get
    let ns = hash seed
    put (gs {seed = ns})
    return ns

I have included two versions that are functionally identical. The first I consider idiomatic Haskell, while the second might be a bit more friendly to beginner eyes.

Now we can elegantly make a function to teleport the player to a random location. Note that this operation doesn't need to produce a result, it just modifies the state. Hence we use unit () as the resulting type.

teleportPlayer :: State GameState ()
teleportPlayer = do
    newx <- randomInt
    newy <- randomInt
    gs <- get
    put gs { playerX = newx, playerY = newy }
    return ()

Now we can change our update loop to teleport the player each frame.

update :: GameState -> GameState
update = fst . runState teleportPlayer
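If you want to actually run this, hash needs a body. The sketch below assumes a simple linear congruential step as a stand-in; that choice, the start value and the deriving clause are mine and only there for illustration. It is not a good random number generator.

```haskell
newtype State a b = State (a -> (a, b))

runState :: State a b -> a -> (a, b)
runState (State f) = f

instance Functor (State a) where
    fmap f st = State $ \s -> let (ns, res) = runState st s in (ns, f res)

instance Applicative (State a) where
    pure x = State $ \s -> (s, x)
    sf <*> sarg = State $ \s ->
        let (ns, f)    = runState sf s
            (nns, arg) = runState sarg ns
        in (nns, f arg)

instance Monad (State a) where
    (State sfunc) >>= f = State $ \s ->
        let (s', x) = sfunc s in runState (f x) s'

get :: State a a
get = State $ \s -> (s, s)

put :: a -> State a ()
put newstate = State $ \_ -> (newstate, ())

-- deriving added only so the state can be printed and compared.
data GameState = GameState { seed :: Int, score :: Int, playerX :: Int, playerY :: Int }
    deriving (Eq, Show)

-- Assumption: a linear congruential step stands in for the undefined hash.
hash :: Int -> Int
hash s = (s * 1103515245 + 12345) `mod` 2147483648

-- Advance the seed and return it as the random number.
randomInt :: State GameState Int
randomInt = do
    gs <- get
    let ns = hash (seed gs)
    put (gs { seed = ns })
    return ns

teleportPlayer :: State GameState ()
teleportPlayer = do
    newx <- randomInt
    newy <- randomInt
    gs <- get
    put gs { playerX = newx, playerY = newy }

update :: GameState -> GameState
update = fst . runState teleportPlayer

-- Hypothetical starting state for trying things out.
start :: GameState
start = GameState { seed = 1, score = 0, playerX = 0, playerY = 0 }
```

After update start, playerX holds hash 1, while playerY and the new seed both hold hash (hash 1): the seed was threaded through both randomInt calls without any manual plumbing.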

If we want to extend our game with more functionality we are able to compose a lot more state operations easily!

updateTime :: State GameState ()
updateTime = ...

update :: GameState -> GameState
update = fst . runState (teleportPlayer >> updateTime)

Etc. etc. etc.
last modified: Dec 04, 2021 at 01:23