## Balancing Folds

There are three main ways to fold things in Haskell: from the right,
from the left, and from either side. Let’s look at the left vs right
variants first. `foldr` works from the right:

```
foldr (+) 0 [1,2,3]
1 + (2 + (3 + 0))
```

And `foldl` from the left:

```
foldl (+) 0 [1,2,3]
((0 + 1) + 2) + 3
```
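A quick way to see the bracketing each fold produces is to fold with a function that builds a string instead of adding. `showFoldr` and `showFoldl` are illustrative names for this sketch, not library functions:

```haskell
-- Render each fold as a string so the placement of parentheses is visible.
showFoldr :: [String] -> String
showFoldr = foldr (\x acc -> "(" ++ x ++ " + " ++ acc ++ ")") "0"

showFoldl :: [String] -> String
showFoldl = foldl (\acc x -> "(" ++ acc ++ " + " ++ x ++ ")") "0"
```

On `["1","2","3"]` these return `"(1 + (2 + (3 + 0)))"` and `"(((0 + 1) + 2) + 3)"`, the same shapes as written out above.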

As you’ll notice, the result of the two operations above is the same
(6; although one may take much longer than the other). In fact,
*whenever* the result of `foldr` and `foldl` is the same for a pair of
arguments (in this case `+` and `0`), we say that the pair forms a
`Monoid` for some type (well, there’s some extra stuff to do with `0`,
but I only care about associativity at the moment). In this case, the
`Sum` monoid is formed:

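```
newtype Sum a = Sum { getSum :: a }

-- Since GHC 8.4, every Monoid also needs a Semigroup instance
instance Num a => Semigroup (Sum a) where
  Sum x <> Sum y = Sum (x + y)

instance Num a => Monoid (Sum a) where
  mempty  = Sum 0
  mappend = (<>)
```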
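For contrast, here is a small sketch with subtraction, which is not associative: the two folds no longer agree, so `(-)` and `0` do not form a monoid in this sense.

```haskell
-- Addition is associative, so both folds agree...
sumR, sumL :: Integer
sumR = foldr (+) 0 [1, 2, 3] -- 1 + (2 + (3 + 0)) = 6
sumL = foldl (+) 0 [1, 2, 3] -- ((0 + 1) + 2) + 3 = 6

-- ...but subtraction is not, so the bracketing changes the answer.
subR, subL :: Integer
subR = foldr (-) 0 [1, 2, 3] -- 1 - (2 - (3 - 0)) = 2
subL = foldl (-) 0 [1, 2, 3] -- ((0 - 1) - 2) - 3 = -6
```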

When you know that you have a monoid, you can use the `foldMap`
function: this is the third kind of fold. It says that you don’t care
which of `foldl` or `foldr` is used, so the implementer of `foldMap`
can put the parentheses wherever they want:

```
foldMap Sum [1,2,3]
(1 + 2) + (3 + 0)
0 + ((1 + 2) + 3)
((0 + 1) + 2) + 3
```

And we can’t tell the difference from the result. This is a pretty bare-bones introduction to folds and monoids: you won’t need to know more than that for the rest of this post, but the topic area is fascinating and deep, so don’t let me give you the impression that I’ve done anything more than scratch the surface.
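As a quick check of that claim, the sketch below writes out three of the bracketings `foldMap` could choose for `[1,2,3]` and confirms they agree. It uses `Sum` from `Data.Monoid`, which is defined the same way as the newtype above; `bracketings` is just an illustrative name:

```haskell
import Data.Monoid (Sum (..))

-- Three different ways of parenthesising the same monoidal fold.
bracketings :: [Sum Int]
bracketings =
  [ (Sum 1 <> Sum 2) <> (Sum 3 <> mempty)
  , mempty <> ((Sum 1 <> Sum 2) <> Sum 3)
  , ((mempty <> Sum 1) <> Sum 2) <> Sum 3
  ]
```

Every element is `Sum 6`, as is `foldMap Sum [1, 2, 3]`.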

# Other Ways to Fold

Quite often, we *do* care about where the parentheses go.
Take, for instance, a binary tree type, with values at the leaves:

```
data Tree a
  = Empty
  | Leaf a
  | Tree a :*: Tree a

instance Show a => Show (Tree a) where
  show Empty     = "()"
  show (Leaf x)  = show x
  show (l :*: r) = "(" ++ show l ++ "*" ++ show r ++ ")"
```

We can’t (well, shouldn’t) use `foldMap` here, because we would be able
to tell the difference between different arrangements of parentheses:

```
>>> foldMap something [1,2,3]
((1*2)*(3*())) │ (()*((1*2)*3)) │ (((()*1)*2)*3)
───────────────┼────────────────┼───────────────
 ┌1            │ ┌()            │   ┌()
┌┤             │ ┤              │  ┌┤
│└2            │ │ ┌1           │  │└1
┤              │ │┌┤            │ ┌┤
│┌3            │ ││└2           │ │└2
└┤             │ └┤             │ ┤
 └()           │  └3            │ └3
```
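The underlying problem is that `:*:` is not associative. A minimal sketch, reusing the `Tree` type and `Show` instance above (`leftNested` and `rightNested` are illustrative names), makes that concrete: two bracketings of the same three leaves print differently.

```haskell
data Tree a
  = Empty
  | Leaf a
  | Tree a :*: Tree a

instance Show a => Show (Tree a) where
  show Empty     = "()"
  show (Leaf x)  = show x
  show (l :*: r) = "(" ++ show l ++ "*" ++ show r ++ ")"

-- The same three leaves, grouped two different ways.
leftNested, rightNested :: Tree Int
leftNested  = (Leaf 1 :*: Leaf 2) :*: Leaf 3
rightNested = Leaf 1 :*: (Leaf 2 :*: Leaf 3)
```

These show as `((1*2)*3)` and `(1*(2*3))`, so a `foldMap` over `:*:` would leak its choice of parentheses.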

So we use one of the folds which lets us choose the arrangements of parentheses:

```
>>> (foldr (:*:) Empty . map Leaf) [1,2,3,4,5,6]
(1*(2*(3*(4*(5*(6*()))))))
┌1
┤
│┌2
└┤
 │┌3
 └┤
  │┌4
  └┤
   │┌5
   └┤
    │┌6
    └┤
     └()
>>> (foldl (:*:) Empty . map Leaf) [1,2,3,4,5,6]
((((((()*1)*2)*3)*4)*5)*6)
     ┌()
    ┌┤
    │└1
   ┌┤
   │└2
  ┌┤
  │└3
 ┌┤
 │└4
┌┤
│└5
┤
└6
```

The issue is that neither of the generated trees is necessarily what
we want: often, we want something more *balanced*.
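To put a number on “not balanced”, one option is to measure each tree’s depth. The `depth` function below is an illustrative helper, not something from the original code: both folds over six elements produce a tree of depth 6, whereas a balanced tree over the same seven leaves (six values plus `Empty`) would only need depth 3.

```haskell
data Tree a = Empty | Leaf a | Tree a :*: Tree a

-- Depth: the number of :*: nodes on the longest root-to-leaf path.
depth :: Tree a -> Int
depth (l :*: r) = 1 + max (depth l) (depth r)
depth _         = 0

rightTree, leftTree :: Tree Int
rightTree = (foldr (:*:) Empty . map Leaf) [1 .. 6]
leftTree  = (foldl (:*:) Empty . map Leaf) [1 .. 6]
```

In general, folding `n` leaves this way gives depth `n`, so the depth grows linearly rather than logarithmically.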

## TreeFold

To try and find a more balanced fold, let’s (for now) assume we’re
always going to get non-empty input. This will let us simplify the `Tree`
type a little, to:

```
{-# LANGUAGE DeriveFoldable #-}

data Tree a
  = Leaf a
  | Tree a :*: Tree a
  deriving Foldable

instance Show a => Show (Tree a) where
  show (Leaf x)  = show x
  show (l :*: r) = "(" ++ show l ++ "*" ++ show r ++ ")"
```

Then, we can use Jon Fairbairn’s fold described in this email, adapted a bit for our non-empty input:

```
import Data.List.NonEmpty (NonEmpty(..))

treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f = go
  where
    go (x :| [])  = x
    go (a :| b:l) = go (f a b :| pairMap l)
    pairMap (x:y:rest) = f x y : pairMap rest
    pairMap xs         = xs
```

There are two parts to this function: `pairMap` and the `go` helper.
`pairMap` combines adjacent elements in the list using the combining
function. As a top-level function it might look like this:

```
pairMap :: (a -> a -> a) -> [a] -> [a]
pairMap f (x:y:rest) = f x y : pairMap f rest
pairMap _ xs         = xs

pairMap (++) ["a","b","c","d","e"]
-- ["ab","cd","e"]
```

As you can see, it leaves any leftovers untouched at the end of the list.

The `go` helper applies `pairMap` repeatedly to the list until it has
only one element. This gives us much more balanced results than
`foldl` or `foldr` (turn on `-XOverloadedLists` to write non-empty
lists using this syntax):

```
>>> (treeFold (:*:) . fmap Leaf) [1,2,3,4,5,6]
(((1*2)*(3*4))*(5*6))
  ┌1
 ┌┤
 │└2
┌┤
││┌3
│└┤
│ └4
┤
│┌5
└┤
 └6
>>> (treeFold (:*:) . fmap Leaf) [1,2,3,4,5,6,7,8]
(((1*2)*(3*4))*((5*6)*(7*8)))
  ┌1
 ┌┤
 │└2
┌┤
││┌3
│└┤
│ └4
┤
│ ┌5
│┌┤
││└6
└┤
 │┌7
 └┤
  └8
```
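To see what `go` is doing, a small helper can record each intermediate list it works through. `passes` is a hypothetical name used only for illustration; it uses the top-level `pairMap` from above and stops once one element (or none) remains:

```haskell
-- One pairing pass, as in the top-level pairMap above.
pairMap :: (a -> a -> a) -> [a] -> [a]
pairMap f (x : y : rest) = f x y : pairMap f rest
pairMap _ xs             = xs

-- Every intermediate list, from the input down to a single element.
-- (Hypothetical helper, for illustration only.)
passes :: (a -> a -> a) -> [a] -> [[a]]
passes f xs
  | length xs <= 1 = [xs]
  | otherwise      = xs : passes f (pairMap f xs)
```

With six strings the passes are `["ab","cd","ef"]`, `["abcd","ef"]`, `["abcdef"]`. With nine, the leftover `"i"` rides along untouched until the very last pass.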

However, there are still cases where one branch will be much larger than its sibling. The fold fills a balanced binary tree from the left, but any leftover elements are put at the top level. In other words:

```
>>> (treeFold (:*:) . fmap Leaf) [1..9]
((((1*2)*(3*4))*((5*6)*(7*8)))*9)
   ┌1
  ┌┤
  │└2
 ┌┤
 ││┌3
 │└┤
 │ └4
┌┤
││ ┌5
││┌┤
│││└6
│└┤
│ │┌7
│ └┤
│  └8
┤
└9
```

That `9` hanging out on its own there is a problem.
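The problem is not specific to nine elements. The sketch below reuses the non-empty `Tree` and the typewriter `treeFold` from above; `size` and `rootSplit` are illustrative helpers that measure how the root divides its leaves. For any input of length $2^k + 1$ the root has a full $2^k$-leaf branch on one side and a single leaf on the other:

```haskell
import Data.List.NonEmpty (NonEmpty (..))

data Tree a = Leaf a | Tree a :*: Tree a

-- The typewriter treeFold from above.
treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f = go
  where
    go (x :| [])    = x
    go (a :| b : l) = go (f a b :| pairMap l)
    pairMap (x : y : rest) = f x y : pairMap rest
    pairMap xs             = xs

-- Number of leaves in a tree.
size :: Tree a -> Int
size (Leaf _)  = 1
size (l :*: r) = size l + size r

-- Leaf counts of the root's two branches when folding [1..n] (n >= 1).
rootSplit :: Int -> (Int, Int)
rootSplit n = case treeFold (:*:) (fmap Leaf (1 :| [2 .. n])) of
  l :*: r -> (size l, size r)
  Leaf _  -> (1, 0)
```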

## Typewriters and Slaloms

One observation we can make is that `pairMap` always starts from the
same side on each iteration, like a typewriter moving from one line to
the next. This has the consequence of building up the leftovers on one
side, leaving them until the top level.

We can improve the situation slightly by going back and forth, slalom-style, so we consume leftovers on each iteration:

```
treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f = goTo where
  -- Left-to-right pass
  goTo (y :| []) = y
  goTo (a :| b : rest) = goFro (pairMap f (f a b) rest)
  -- Right-to-left pass over the reversed list, so arguments are flipped
  goFro (y :| []) = y
  goFro (a :| b : rest) = goTo (pairMap (flip f) (f b a) rest)
  -- Pairs up adjacent elements, returning the result reversed
  pairMap g = go [] where
    go ys y (a:b:rest) = go (y:ys) (g a b) rest
    go ys y [z] = z :| y : ys
    go ys y [] = y :| ys
```

Notice that we have to flip the combining function to make sure the ordering is the same on output. For the earlier example, this solves the issue:

```
>>> (treeFold (:*:) . fmap Leaf) [1..9]
(((1*2)*((3*4)*(5*6)))*((7*8)*9))
  ┌1
 ┌┤
 │└2
┌┤
││ ┌3
││┌┤
│││└4
│└┤
│ │┌5
│ └┤
│  └6
┤
│ ┌7
│┌┤
││└8
└┤
 └9
```
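The `flip` in `goFro` is what keeps the output in order. A sketch that parameterises the slalom fold over whether the backward pass flips its arguments makes this visible with a non-commutative operation like `(++)`. `slalomWith` is a hypothetical name; with `True` it behaves like the `treeFold` above:

```haskell
import Data.List.NonEmpty (NonEmpty (..))

-- The slalom fold, with a switch for flipping on the backward pass.
slalomWith :: Bool -> (a -> a -> a) -> NonEmpty a -> a
slalomWith useFlip f = goTo
  where
    back = if useFlip then flip f else f
    goTo (y :| [])       = y
    goTo (a :| b : rest) = goFro (pairMap f (f a b) rest)
    goFro (y :| [])       = y
    goFro (a :| b : rest) = goTo (pairMap back (back a b) rest)
    -- Pairs adjacent elements, accumulating the result in reverse.
    pairMap g = go []
      where
        go ys y (a : b : rest) = go (y : ys) (g a b) rest
        go ys y [z]            = z :| y : ys
        go ys y []             = y :| ys
```

With the flip, folding the letters `a` to `i` gives back `"abcdefghi"`; without it, the backward passes glue pieces together in reverse and the result is `"ighabefcd"`.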

It does *not* build up the tree as balanced as it possibly
could, though:

```
>>> (treeFold (:*:) . fmap Leaf) [1,2,3,4,5,6]
((1*2)*((3*4)*(5*6)))
 ┌1
┌┤
│└2
┤
│ ┌3
│┌┤
││└4
└┤
 │┌5
 └┤
  └6
```

There are four elements in the right branch and two in the left in the above example; three in each would be optimal.

Wait—optimal in what sense, exactly? What do we mean when we say one tree is more balanced than another? Let’s say the “balance factor” is the largest difference in size of two sibling trees:

```
balFac :: Tree a -> Integer
balFac = fst . go where
  go :: Tree a -> (Integer, Integer)
  go (Leaf _) = (0, 1)
  go (l :*: r) = (lb `max` rb `max` abs (rs - ls), rs + ls) where
    (lb,ls) = go l
    (rb,rs) = go r
```

And one tree is more balanced than another if it has a smaller balance factor.
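Here is a sketch that puts `balFac` to work on both folds so far. The `typewriter` and `slalom` functions are the two `treeFold` definitions above under new names, and `leaves` is a small illustrative helper for building the input:

```haskell
import Data.List.NonEmpty (NonEmpty (..))

data Tree a = Leaf a | Tree a :*: Tree a

-- balFac from above: the largest size difference between two siblings.
balFac :: Tree a -> Integer
balFac = fst . go
  where
    go :: Tree a -> (Integer, Integer) -- (balance factor, size)
    go (Leaf _)  = (0, 1)
    go (l :*: r) = (lb `max` rb `max` abs (rs - ls), rs + ls)
      where
        (lb, ls) = go l
        (rb, rs) = go r

-- The typewriter treeFold from above.
typewriter :: (a -> a -> a) -> NonEmpty a -> a
typewriter f = go
  where
    go (x :| [])    = x
    go (a :| b : l) = go (f a b :| pairMap l)
    pairMap (x : y : rest) = f x y : pairMap rest
    pairMap xs             = xs

-- The slalom treeFold from above.
slalom :: (a -> a -> a) -> NonEmpty a -> a
slalom f = goTo
  where
    goTo (y :| [])       = y
    goTo (a :| b : rest) = goFro (pairRev f (f a b) rest)
    goFro (y :| [])       = y
    goFro (a :| b : rest) = goTo (pairRev (flip f) (f b a) rest)
    -- Pairs adjacent elements, returning the result reversed.
    pairRev g = go []
      where
        go ys y (a : b : rest) = go (y : ys) (g a b) rest
        go ys y [z]            = z :| y : ys
        go ys y []             = y :| ys

-- Leaves 1..n as a non-empty list (n >= 1).
leaves :: Int -> NonEmpty (Tree Int)
leaves n = fmap Leaf (1 :| [2 .. n])
```

For nine elements the typewriter fold has a balance factor of 7 (the eight-leaf branch against the lone `9`), while the slalom fold gets it down to 3; on six elements both give 2. On inputs of length $2^k + 1$ the typewriter’s balance factor is $2^k - 1$.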

There’s effectively no limit on the balance factor for the typewriter
method: when the input is one larger than a power of two, it’ll stick
the one extra element in one branch and the rest in the other (as with
`[1..9]` in the example above).

For the slalom method, it looks like there’s something more interesting going on, limit-wise. I haven’t been able to verify this formally (yet), but from what I can tell, a tree of height $n$ will have at most a balance factor of the $n$th Jacobsthal number. That’s (apparently) also the number of ways to tie a tie using $n + 2$ turns.

That was just gathered from some quick experiments and oeis.org, but it seems to make sense intuitively. Jacobsthal numbers are defined like this:

```
j 0 = 0
j 1 = 1
j n = j (n-1) + 2 * j (n-2)
```

So, at the top level, there’s the imbalance caused by the second-last
`pairMap`, plus the imbalance caused by the third-to-last. However, the
third-to-last imbalance is twice what it was at that level, because it
is now working with an already-paired-up list. Why isn’t the
second-last imbalance also doubled? Because it’s counteracted by the
fact that we turned around: the imbalance sits in a leftover element.
At least that’s what my intuition is at this point.
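Evaluated as written, `j` recomputes the same subproblems exponentially often. A sketch of the same recurrence as a lazy list, checked against the standard closed form $J_n = (2^n - (-1)^n)/3$ (`jacobsthal` and `jacobsthalClosed` are illustrative names):

```haskell
-- The recurrence as a lazy list: each element is the previous one
-- plus twice the one before that.
jacobsthal :: [Integer]
jacobsthal = 0 : 1 : zipWith (\older newer -> newer + 2 * older) jacobsthal (drop 1 jacobsthal)

-- Closed form: J(n) = (2^n - (-1)^n) / 3.
jacobsthalClosed :: Integer -> Integer
jacobsthalClosed n = (2 ^ n - (-1) ^ n) `div` 3
```

The first few values are 0, 1, 1, 3, 5, 11, 21, 43, 85.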

The minimum balance factor is, of course, one. Unfortunately, to achieve that, I lost some of the properties of the previous folds:

## Lengths

Up until now, I have been avoiding taking the length of the incoming list. It would lose a lot of laziness, cause an extra traversal, and generally seems like an ugly solution. Nonetheless, it gives the most balanced results I could find so far:

```
treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f (x:|xs) = go (length (x:xs)) (x:xs) where
  go 1 [y] = y
  go n ys = f (go m a) (go (n-m) b) where
    (a,b) = splitAt m ys
    m = n `div` 2
```
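Since every call splits its `n` elements into halves of sizes ``n `div` 2`` and ``n - n `div` 2``, sibling sizes can differ by at most one. A sketch that checks this over many input lengths, with `balance` as an `Int` version of `balFac` from above:

```haskell
import Data.List.NonEmpty (NonEmpty (..))

data Tree a = Leaf a | Tree a :*: Tree a

-- The splitAt-based treeFold from above.
treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f (x :| xs) = go (length (x : xs)) (x : xs)
  where
    go 1 [y] = y
    go n ys  = f (go m a) (go (n - m) b)
      where
        (a, b) = splitAt m ys
        m      = n `div` 2

-- Largest size difference between any two siblings.
balance :: Tree a -> Int
balance = fst . go
  where
    go :: Tree a -> (Int, Int) -- (balance, size)
    go (Leaf _)  = (0, 1)
    go (l :*: r) = (maximum [lb, rb, abs (rs - ls)], ls + rs)
      where
        (lb, ls) = go l
        (rb, rs) = go r
```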

`splitAt` is an inefficient operation, but if we let the left-hand call
return its unused input from the list, we can avoid it:

```
treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f (x:|xs) = fst (go (length (x:xs)) (x:xs)) where
  go 1 (y:ys) = (y,ys)
  go n ys = (f l r, rs) where
    (l,ls) = go m ys
    (r,rs) = go (n-m) ls
    m = n `div` 2
```
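Threading the leftover input through the recursion should not change the shape of the result. A sketch that checks this by rendering both versions’ trees as strings for every input length up to 100; `treeFoldSplit` and `treeFoldThread` are just labels for the two definitions above:

```haskell
import Data.List.NonEmpty (NonEmpty (..))

-- The splitAt-based version.
treeFoldSplit :: (a -> a -> a) -> NonEmpty a -> a
treeFoldSplit f (x :| xs) = go (length (x : xs)) (x : xs)
  where
    go 1 [y] = y
    go n ys  = f (go m a) (go (n - m) b)
      where
        (a, b) = splitAt m ys
        m      = n `div` 2

-- The version where each call also returns the input it did not consume.
treeFoldThread :: (a -> a -> a) -> NonEmpty a -> a
treeFoldThread f (x :| xs) = fst (go (length (x : xs)) (x : xs))
  where
    go 1 (y : ys) = (y, ys)
    go n ys       = (f l r, rs)
      where
        (l, ls) = go m ys
        (r, rs) = go (n - m) ls
        m       = n `div` 2
```

For six elements both produce `((1*(2*3))*(4*(5*6)))`.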

Finally, you may have spotted the state monad in this last version. We can make the similarity explicit:

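```
import Control.Monad.Trans.State (evalState, state) -- transformers

treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f (x:|xs) = evalState (go (length (x:xs))) (x:xs) where
  -- Pop a single element off the remaining input
  go 1 = state (\(y:ys) -> (y,ys))
  go n = do
    let m = n `div` 2
    l <- go m
    r <- go (n-m)
    return (f l r)
```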

And there you have it: three different ways to fold in a more balanced way. Perhaps surprisingly, the first is the fastest in my tests. I’d love to hear if there’s a more balanced version (which is lazy, ideally) that is just as efficient as the first implementation.

# Stable Summation

I have found two uses for these folds beyond simply constructing more
balanced binary trees. The first is summation of floating-point
numbers. If you sum floating-point numbers in the usual way with
`foldl'` (or, indeed, with an accumulator in an imperative language),
you will see an error growth of $\mathcal{O}(n)$, where $n$ is the
number of floats you’re summing.

A well-known solution to this problem is the Kahan summation algorithm. It carries with it a running compensation for accumulating errors, giving it $\mathcal{O}(1)$ error growth. There are two downsides to the algorithm: it takes four times the number of numerical operations to perform, and it is inherently sequential.
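A minimal sketch of the difference, assuming plain `Double`s: `naiveSum` is the usual strict left fold, and `kahanSum` is a textbook Kahan summation (an illustration, not a tuned implementation). Summing a million copies of `0.1` shows the naive result drifting away from 100000 while the compensated one stays much closer:

```haskell
import Data.List (foldl')

-- Ordinary left-to-right summation: error can grow linearly with the length.
naiveSum :: [Double] -> Double
naiveSum = foldl' (+) 0

-- Textbook Kahan summation: c carries the rounding error from the last step.
kahanSum :: [Double] -> Double
kahanSum = fst . foldl' step (0, 0)
  where
    step (s, c) x =
      let y  = x - c       -- correct the next input by the previous error
          t  = s + y       -- low-order bits of y may be lost here
          c' = (t - s) - y -- rounding error of s + y (algebraically zero)
      in t `seq` c' `seq` (t, c')
```

Note the four floating-point operations per element in `step`, compared with one in `naiveSum`.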

For that reason, it’s often not used in practice: instead, floats are
summed *pairwise*, in a manner often referred to as cascade
summation. This is what’s used in NumPy. The error growth isn’t quite
as good, at $\mathcal{O}(\log{n})$, but it takes the exact same number
of operations as normal summation. On top of that:

# Parallelization

Dividing a fold into roughly-equal chunks is exactly the kind of problem encountered when trying to parallelize certain algorithms. Adapting the folds above so that their work is performed in parallel is surprisingly easy:

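```
import Control.Parallel (par, pseq) -- from the parallel package
import GHC.Conc (numCapabilities)

splitPar :: (a -> a -> a) -> (Int -> a) -> (Int -> a) -> Int -> a
splitPar f = go
  where
    -- No budget left: combine sequentially
    go l r 0 = f (l 0) (r 0)
    -- Spark the left side, evaluate the right, then combine
    go l r n = lt `par` (rt `pseq` f lt rt)
      where
        lt = l (n-m)
        rt = r m
        m = n `div` 2

-- treeFold is any of the balanced folds above; leaves ignore the budget
treeFoldParallel :: (a -> a -> a) -> NonEmpty a -> a
treeFoldParallel f xs =
  treeFold (splitPar f) (const <$> xs) numCapabilities
```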

The above will split the fold into `numCapabilities` chunks, and
perform each one in parallel. `numCapabilities` is a constant defined
in `GHC.Conc`: it’s the number of threads which can be run
simultaneously at any one time. Alternatively, you could have the
function take a parameter for how many chunks to split the computation
into. You could also have the fold adapt as it went, choosing whether
or not to spark based on how many sparks exist at any given time:

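```
import Control.Applicative (liftA2)
import Control.Monad.ST (runST)
import Control.Monad.ST.Unsafe (unsafeIOToST)
import Control.Parallel (par)
import Data.Bool (bool)
import GHC.Conc (getNumCapabilities, numSparks)

-- Spark `a` unless there are already more sparks than capabilities,
-- in which case just evaluate it with `seq`
parseq :: a -> b -> b
parseq a b =
  runST (bool (par a b) (seq a b) <$>
    unsafeIOToST (liftA2 (>) numSparks getNumCapabilities))

-- Lazy.treeFold is the lazy fold from the library mentioned below
treeFoldAdaptive :: (a -> a -> a) -> a -> [a] -> a
treeFoldAdaptive f =
  Lazy.treeFold (\l r -> r `parseq` (l `parseq` f l r))
```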

Adapted from this comment by Edward Kmett. This is actually the fastest version of all the folds.
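The fixed-budget alternative mentioned above, taking the number of chunks as a parameter instead of reading `numCapabilities`, might look something like the sketch below. `treeFoldParallelN` is a hypothetical name; it takes `par` and `pseq` from `GHC.Conc` to avoid depending on the `parallel` package, and uses the typewriter `treeFold` as its underlying fold:

```haskell
import Data.List.NonEmpty (NonEmpty (..))
import GHC.Conc (par, pseq)

-- splitPar from above: divides a spark budget of n between the two sides.
splitPar :: (a -> a -> a) -> (Int -> a) -> (Int -> a) -> Int -> a
splitPar f = go
  where
    go l r 0 = f (l 0) (r 0)
    go l r n = lt `par` (rt `pseq` f lt rt)
      where
        lt = l (n - m)
        rt = r m
        m  = n `div` 2

-- Any of the balanced folds works here; this uses the typewriter version.
treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f = go
  where
    go (x :| [])    = x
    go (a :| b : l) = go (f a b :| pairMap l)
    pairMap (x : y : rest) = f x y : pairMap rest
    pairMap xs             = xs

-- Hypothetical variant: the spark budget is an explicit parameter.
treeFoldParallelN :: Int -> (a -> a -> a) -> NonEmpty a -> a
treeFoldParallelN chunks f xs = treeFold (splitPar f) (fmap const xs) chunks
```

Because each leaf is wrapped in `const`, the budget only steers where `par` is applied; the result is the same as the sequential fold. To actually run sparks in parallel, the program needs to be compiled with `-threaded` and run with `+RTS -N`.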

All of this is provided in a library I’ve put up on Hackage.