Patterns for tracking state in recursive Haskell code


A frequent pattern in my Haskell code is element-wise recursion that transforms a list while carrying some state built from the elements seen so far. It usually looks something like this:

doSomething :: (SomeA a, SomeB b) => [a] -> [b]
doSomething xs = doSomethingWithState xs []
    where
        doSomethingWithState [] _ = []
        doSomethingWithState (x:xs) state
            | someTest x state = someChange x : (doSomethingWithState xs newState)
            | otherwise = someOtherChange x : (doSomethingWithState xs newState)

For example, say I wanted to count how many times each element appears in a list, turning something like [1, 3, 3, 3, 4, 7, 7, 8, 8, 9] into [(9,1),(8,2),(7,2),(4,1),(3,3),(1,1)].

I'd probably do something like the following:

import Data.Maybe (isNothing)

counts :: (Eq a) => [a] -> [(a, Int)]
counts xs = countsWithState xs []
    where
        countsWithState [] state = state               -- End of list, stop here
        countsWithState (x:xs) state = countsWithState xs $ transformState state x
        transformState state x                         -- To get the state...
            | isNothing $ lookup x state = (x, 1) : state -- Add new elem if new
            | otherwise = incrElem x state             -- Increment elem if not
        incrElem x [] = []                             -- Should never be reached
        incrElem x ((index, value):elems)              -- Searching through list...
            | index == x = (index, (value+1)) : elems  -- Increment if found
            | otherwise = (index, value) : incrElem x elems -- Try next if not

A much simpler but very similar example is keeping the running average of the elements in a list. This transforms something like [1, 7, 4, 18, 7, 1, 8, 2, 8, 6, 18, 12] into [1.0, 4.0, 4.0, 7.5, 7.4, 6.33..., 6.57..., 6.0, 6.22..., 6.2, 7.27..., 7.66...], where every element of the output list is the average of the corresponding input element and all input elements before it. I might do something like this:

runningAvg :: (Fractional a) => [a] -> [a]
runningAvg xs = runningAvgWithState xs 0 1
    where
        runningAvgWithState [] _ _ = []
        runningAvgWithState (x:xs) currentSum currentElems
            = (currentSum + x) / currentElems
            : runningAvgWithState xs (currentSum + x) (currentElems + 1)

Notice the pattern is the same: take a recursive function over a list, define it in terms of a hidden helper that takes extra state arguments, and at each step transform the state and emit computed results as necessary. This pattern emerges all the time in my Haskell code.

Is there a more natural way of implementing this sort of behavior, without a more complicated xWithState function running the show and adding unnecessary verbosity and complexity?

Best Answer

Your doSomething is more or less mapAccumL from Data.List, except you've thrown away the accumulator (i.e., state) at the end. That is, you could write it as:

doSomething :: [a] -> [b]
doSomething = snd . mapAccumL step []
  where step state x = (newState, newX)
          where newX | someTest x state = someChange x
                     | otherwise        = someOtherChange x
                newState = state
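To make the template concrete, here is a minimal sketch with a made-up rule. The name markFirsts and the choice of state (the list of elements seen so far) are assumptions for illustration only, not something from the question. Each element is paired with True if it is a first occurrence and False if it has been seen before:

```haskell
import Data.List (mapAccumL)

-- Hypothetical example: tag each element with whether it is new.
-- The accumulator ("state") is the list of elements seen so far.
markFirsts :: Eq a => [a] -> [(a, Bool)]
markFirsts = snd . mapAccumL step []
  where
    step seen x
      | x `elem` seen = (seen, (x, False))    -- already seen: state unchanged
      | otherwise     = (x : seen, (x, True)) -- new element: remember it
```

With this definition, markFirsts [1, 3, 3, 4, 1] gives [(1,True),(3,True),(3,False),(4,True),(1,False)]. The step function returns the new state first and the output element second, which is the order mapAccumL expects.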

In particular, the running average might look like:

runningAvg :: (Fractional a) => [a] -> [a]
runningAvg = snd . mapAccumL step (0,0)
  where step (total, n) x = ((total', n'), total' / fromIntegral n')
          where total' = total + x
                n' = n + 1
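To see what snd is discarding, the same step function can be run without it. In the sketch below, the name runningAvgWithFinal and the explicit Int counter are assumptions added here (the Int avoids relying on type defaulting); the step logic is the one shown above:

```haskell
import Data.List (mapAccumL)

-- Same step as runningAvg, but the final accumulator (total, count)
-- is returned alongside the list of running averages.
runningAvgWithFinal :: Fractional a => [a] -> ((a, Int), [a])
runningAvgWithFinal = mapAccumL step (0, 0)
  where
    step (total, n) x = ((total', n'), total' / fromIntegral n')
      where
        total' = total + x  -- sum of all elements up to and including x
        n'     = n + 1      -- number of elements seen so far
```

For example, runningAvgWithFinal [1, 7, 4, 18] evaluates to ((30.0,4),[1.0,4.0,4.0,7.5]): the first component is the final state, the second is the output list that runningAvg returns.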

Your counting example is somewhat different, since it doesn't produce a list of the same "shape" as the input. Instead, the "state" is the set of counts, and you return the final "state" (final set of counts) at the end. This is just a foldl, as others have noted.

count :: (Eq a) => [a] -> [(a, Int)]
count = foldl step []
  where step cnts x = bumpCount x cnts

This would be straightforward were it not for the fact that maintaining a set of counts as an association list makes bumpCount kind of gross. @amon's version of bumpCount (i.e., count') seems fine, or you could write it as the following fold:

bumpCount :: (Eq a) => a -> [(a, Int)] -> [(a, Int)]
bumpCount x = foldr step [(x,1)] . tails
  where step ((y,n):rest) acc | x == y    = (y,n+1) : rest
                              | otherwise = (y,n)   : acc
        step [] acc = acc

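By the way, this bumpCount fold is actually pretty neat. Despite being a "fold", it stops searching the list once it finds and bumps a matching count: in the matching branch the result is built from rest and never touches acc, so thanks to laziness the remainder of the fold is never evaluated.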
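The pieces above can be put together and checked with a small sketch. Here bumpCount is unchanged, count uses foldl' instead of foldl (a substitution made here, on the assumption that a strict left fold is preferable for an accumulator), and bumpInInfinite is a hypothetical demo name:

```haskell
import Data.List (foldl', tails)

-- bumpCount as given above.
bumpCount :: Eq a => a -> [(a, Int)] -> [(a, Int)]
bumpCount x = foldr step [(x, 1)] . tails
  where
    step ((y, n) : rest) acc
      | x == y    = (y, n + 1) : rest -- found: reuse the tail, acc is never forced
      | otherwise = (y, n) : acc      -- not found yet: keep this entry, continue lazily
    step [] acc = acc                 -- ran off the end: acc is [(x, 1)]

-- count as above, with foldl' swapped in for foldl.
count :: Eq a => [a] -> [(a, Int)]
count = foldl' (flip bumpCount) []

-- Hypothetical demo: bumping a key in an infinite association list.
-- This only terminates because the fold stops at the first match.
bumpInInfinite :: [(Integer, Int)]
bumpInInfinite = take 4 (bumpCount 3 (zip [1 ..] (repeat 1)))
```

Here count [1, 3, 3, 3, 4, 7, 7, 8, 8, 9] gives [(1,1),(3,3),(4,1),(7,2),(8,2),(9,1)]. The counts come out in first-seen order, the reverse of the question's version, because this bumpCount appends new keys at the end rather than prepending them. bumpInInfinite evaluates to [(1,1),(2,1),(3,2),(4,1)].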