Balancing Folds
There are three main ways to fold things in Haskell: from the right, from the left, and from either side. Let’s look at the left vs right variants first. foldr works from the right:
foldr (+) 0 [1,2,3]
1 + (2 + (3 + 0))
And foldl from the left:
foldl (+) 0 [1,2,3]
((0 + 1) + 2) + 3
As you’ll notice, the result of the two operations above is the same (6; although one may take much longer than the other). In fact, whenever the result of foldr and foldl is the same for a pair of arguments (in this case + and 0), we say that that pair forms a Monoid for some type (well, there’s some extra stuff to do with 0, but I only care about associativity at the moment). In this case, the Sum monoid is formed:
newtype Sum a = Sum { getSum :: a }

instance Num a => Monoid (Sum a) where
  mempty = Sum 0
  mappend (Sum x) (Sum y) = Sum (x + y)
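(One small caveat if you’re trying this on a recent compiler: since base 4.11, Monoid has a Semigroup superclass, so you’d also need an instance along these lines, with mappend then defaulting to (<>):)

instance Num a => Semigroup (Sum a) where
  Sum x <> Sum y = Sum (x + y)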
When you know that you have a monoid, you can use the foldMap function: this is the third kind of fold. It says that you don’t care which of foldl or foldr is used, so the implementer of foldMap can put the parentheses wherever they want:
foldMap Sum [1,2,3]
(1 + 2) + (3 + 0)
0 + ((1 + 2) + 3)
((0 + 1) + 2) + 3
And we can’t tell the difference from the result. This is a pretty bare-bones introduction to folds and monoids: you won’t need to know more than that for the rest of this post, but the topic area is fascinating and deep, so don’t let me give you the impression that I’ve done anything more than scratch the surface.
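Just to double-check that the grouping really doesn’t matter for the result (using the getSum accessor from the newtype above):

>>> getSum (foldMap Sum [1,2,3])
6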
Other Ways to Fold
Quite often, we do care about where the parentheses go. Take, for instance, a binary tree type, with values at the leaves:
data Tree a
  = Empty
  | Leaf a
  | Tree a :*: Tree a

instance Show a => Show (Tree a) where
  show Empty = "()"
  show (Leaf x) = show x
  show (l :*: r) = "(" ++ show l ++ "*" ++ show r ++ ")"
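For example, a small tree renders like this (just an illustrative value, not one used later in the post):

>>> Leaf 1 :*: (Leaf 2 :*: Empty)
(1*(2*()))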
We can’t (well, shouldn’t) use foldMap here, because we would be able to tell the difference between different arrangements of parentheses:
>>> foldMap something [1,2,3]
((1*2)*(3*())) │ (()*((1*2)*3)) │ (((()*1)*2)*3)
So we use one of the folds which lets us choose the arrangements of parentheses:
>>> (foldr (:*:) Empty . map Leaf) [1,2,3,4,5,6]
(1*(2*(3*(4*(5*(6*()))))))
>>> (foldl (:*:) Empty . map Leaf) [1,2,3,4,5,6]
((((((()*1)*2)*3)*4)*5)*6)
The issue is that neither of the trees generated is necessarily what we want: often, we want something more balanced.
TreeFold
To try and find a more balanced fold, let’s (for now) assume we’re always going to get non-empty input. This will let us simplify the Tree type a little, to:
data Tree a
  = Leaf a
  | Tree a :*: Tree a
  deriving Foldable

instance Show a => Show (Tree a) where
  show (Leaf x) = show x
  show (l :*: r) = "(" ++ show l ++ "*" ++ show r ++ ")"
Then, we can use Jon Fairbairn’s fold described in this email, adapted a bit for our non-empty input:
import Data.List.NonEmpty (NonEmpty(..))
treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f = go
  where
    go (x :| []) = x
    go (a :| b:l) = go (f a b :| pairMap l)
    pairMap (x:y:rest) = f x y : pairMap rest
    pairMap xs = xs
There are two parts to this function: pairMap and the go helper. pairMap combines adjacent elements in the list using the combining function. As a top-level function it might look like this:
pairMap f (x:y:rest) = f x y : pairMap f rest
pairMap f xs = xs

pairMap (++) ["a","b","c","d","e"]
-- ["ab","cd","e"]
As you can see, it leaves any leftovers untouched at the end of the list.
The go helper applies pairMap repeatedly to the list until it has only one element. This gives us much more balanced results than foldl or foldr (turn on -XOverloadedLists to write non-empty lists using this syntax):
>>> (treeFold (:*:) . fmap Leaf) [1,2,3,4,5,6]
(((1*2)*(3*4))*(5*6))
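To see where that shape comes from, here’s a rough sketch of how the evaluation proceeds for that six-element input (leaves abbreviated to bare numbers; this is my own trace, not output from the code):

go (1 :| [2,3,4,5,6])
go ((1*2) :| [3*4, 5*6])          -- go combines the first two; pairMap pairs up the rest
go (((1*2)*(3*4)) :| [5*6])       -- the leftover 5*6 is carried along untouched
go ((((1*2)*(3*4))*(5*6)) :| [])
((1*2)*(3*4))*(5*6)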
>>> (treeFold (:*:) . fmap Leaf) [1,2,3,4,5,6,7,8]
(((1*2)*(3*4))*((5*6)*(7*8)))
However, there are still cases where one branch will be much larger than its sibling. The fold fills a balanced binary tree from the left, but any leftover elements are put at the top level. In other words:
>>> (treeFold (:*:) . fmap Leaf) [1..9]
((((1*2)*(3*4))*((5*6)*(7*8)))*9)
That 9 hanging out on its own there is a problem.
Typewriters and Slaloms
One observation we can make is that pairMap
always starts from the same
side on each iteration, like a typewriter moving from one line to the
next. This has the consequence of building up the leftovers on one side,
leaving them until the top level.
We can improve the situation slightly by going back and forth, slalom-style, so we consume leftovers on each iteration:
treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f = goTo
  where
    goTo (y :| []) = y
    goTo (a :| b : rest) = goFro (pairMap f (f a b) rest)
    goFro (y :| []) = y
    goFro (a :| b : rest) = goTo (pairMap (flip f) (f b a) rest)
    pairMap f = go []
      where
        go ys y (a:b:rest) = go (y:ys) (f a b) rest
        go ys y [z] = z :| y : ys
        go ys y [] = y :| ys
Notice that we have to flip the combining function to make sure the ordering is the same on output. For the earlier example, this solves the issue:
>>> (treeFold (:*:) . fmap Leaf) [1..9]
(((1*2)*((3*4)*(5*6)))*((7*8)*9))
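Roughly, the intermediate lists for [1..9] look like this (leaves abbreviated again; a sketch of my own rather than real output). Each pass reverses the list it produces, which is why the combining function has to be flipped on the way back:

goTo  (1 :| [2,3,4,5,6,7,8,9])
goFro (9 :| [7*8, 5*6, 3*4, 1*2])
goTo  ((1*2) :| [(3*4)*(5*6), (7*8)*9])
goFro (((7*8)*9) :| [(1*2)*((3*4)*(5*6))])
goTo  ((((1*2)*((3*4)*(5*6)))*((7*8)*9)) :| [])
((1*2)*((3*4)*(5*6)))*((7*8)*9)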
It does not build up the tree as balanced as it possibly could, though:
>>> (treeFold (:*:) . fmap Leaf) [1,2,3,4,5,6]
((1*2)*((3*4)*(5*6)))
There are four elements in the right branch, and two in the left in the above example. Three in each would be optimal.
Wait—optimal in what sense, exactly? What do we mean when we say one tree is more balanced than another? Let’s say the “balance factor” is the largest difference in size of two sibling trees:
balFac :: Tree a -> Integer
balFac = fst . go
  where
    go :: Tree a -> (Integer, Integer)
    go (Leaf _) = (0, 1)
    go (l :*: r) = (lb `max` rb `max` abs (rs - ls), rs + ls)
      where
        (lb,ls) = go l
        (rb,rs) = go r
And one tree is more balanced than another if it has a smaller balance factor.
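To make that concrete, here are the balance factors of the two nine-element trees above (typewriter and slalom are just my names for the two results):

>>> let typewriter = (((Leaf 1:*:Leaf 2):*:(Leaf 3:*:Leaf 4)):*:((Leaf 5:*:Leaf 6):*:(Leaf 7:*:Leaf 8))):*:Leaf 9
>>> let slalom = ((Leaf 1:*:Leaf 2):*:((Leaf 3:*:Leaf 4):*:(Leaf 5:*:Leaf 6))):*:((Leaf 7:*:Leaf 8):*:Leaf 9)
>>> (balFac typewriter, balFac slalom)
(7,3)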
There’s effectively no limit on the balance factor for the typewriter
method: when the input is one larger than a power of two, it’ll stick
the one extra in one branch and the rest in another (as with [1..9]
in the example above).
For the slalom method, it looks like there’s something more interesting going on, limit-wise. I haven’t been able to verify this formally (yet), but from what I can tell, a tree of height n will have at most a balance factor of the nth Jacobsthal number. That’s (apparently) also the number of ways to tie a tie using n turns.
That was just gathered from some quick experiments and oeis.org, but it seems to make sense intuitively. Jacobsthal numbers are defined like this:
j 0 = 0
j 1 = 1
j n = j (n-1) + 2 * j (n-2)
So, at the top level, there’s the imbalance caused by the second-last pairMap pass, plus the imbalance
caused by the third-to-last. However, the third-to-last imbalance is
twice what it was at that level, because it is now working with an
already-paired-up list. Why isn’t the second last imbalance also
doubled? Because it’s counteracted by the fact that we turned around:
the imbalance is in an element that’s a leftover element. At least
that’s what my intuition is at this point.
The minimum balance factor is, of course, one. Unfortunately, to achieve that, I lost some of the properties of the previous folds:
Lengths
Up until now, I have been avoiding taking the length of the incoming list. It would lose a lot of laziness, cause an extra traversal, and generally seems like an ugly solution. Nonetheless, it gives the most balanced results I could find so far:
treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f (x:|xs) = go (length (x:xs)) (x:xs)
  where
    go 1 [y] = y
    go n ys = f (go m a) (go (n-m) b)
      where
        (a,b) = splitAt m ys
        m = n `div` 2
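With this version, six elements finally split three and three (an illustrative run):

>>> (treeFold (:*:) . fmap Leaf) [1,2,3,4,5,6]
((1*(2*3))*(4*(5*6)))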
splitAt is an inefficient operation, but if we let the left-hand call return its unused input from the list, we can avoid it:
treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f (x:|xs) = fst (go (length (x:xs)) (x:xs))
  where
    go 1 (y:ys) = (y,ys)
    go n ys = (f l r, rs)
      where
        (l,ls) = go m ys
        (r,rs) = go (n-m) ls
        m = n `div` 2
Finally, you may have spotted the state monad in this last version. We can make the similarity explicit:
import Control.Monad.State (evalState, state)

treeFold :: (a -> a -> a) -> NonEmpty a -> a
treeFold f (x:|xs) = evalState (go (length (x:xs))) (x:xs)
  where
    go 1 = state (\(y:ys) -> (y,ys))
    go n = do
      let m = n `div` 2
      l <- go m
      r <- go (n-m)
      return (f l r)
And there you have it: three different ways to fold in a more balanced way. Perhaps surprisingly, the first is the fastest in my tests. I’d love to hear if there’s a more balanced version (which is lazy, ideally) that is just as efficient as the first implementation.
Stable Summation
I have found two other uses for these folds other than simply constructing more balanced binary trees. The first is summation of floating-point numbers. If you sum floating-point numbers in the usual way with foldl' (or, indeed, with an accumulator in an imperative language), you will see an error growth of O(n), where n is the number of floats you’re summing.
A well-known solution to this problem is the Kahan summation algorithm. It carries with it a running compensation for accumulating errors, giving it O(1) error growth. There are two downsides to the algorithm: it takes four times the number of numerical operations to perform, and it isn’t parallel.
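For reference, here’s a minimal sketch of Kahan summation (kahanSum is my own name for it, and in practice you’d want a strict accumulator rather than the lazy pair used here):

import Data.List (foldl')

kahanSum :: [Double] -> Double
kahanSum = fst . foldl' step (0, 0)
  where
    -- s is the running sum, c is the running compensation for lost low-order bits
    step (s, c) x = (t, (t - s) - y)
      where
        y = x - c
        t = s + y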
For that reason, it’s often not used in practice: instead, floats are summed pairwise, in a manner often referred to as cascade summation. This is what’s used in NumPy. The error growth isn’t quite as good (O(log n)), but it takes the exact same number of operations as normal summation.
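In terms of this post’s folds, pairwise summation is just the balanced fold specialised to addition. A minimal sketch (sumPairwise is an illustrative name; real implementations such as NumPy’s also switch to a plain loop below a small block size):

sumPairwise :: NonEmpty Double -> Double
sumPairwise = treeFold (+)

On top of that: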
Parallelization
Dividing a fold into roughly-equal chunks is exactly the kind of problem encountered when trying to parallelize certain algorithms. Adapting the folds above so that their work is performed in parallel is surprisingly easy:
splitPar :: (a -> a -> a) -> (Int -> a) -> (Int -> a) -> Int -> a
splitPar f = go
  where
    go l r 0 = f (l 0) (r 0)
    go l r n = lt `par` (rt `pseq` f lt rt)
      where
        lt = l (n-m)
        rt = r m
        m = n `div` 2
treeFoldParallel :: (a -> a -> a) -> NonEmpty a -> a
treeFoldParallel f xs =
    treeFold (splitPar f) (fmap const xs) numCapabilities
The above will split the fold into numCapabilities chunks, and perform each one in parallel. numCapabilities is a constant defined in GHC.Conc: it’s the number of threads which can be run simultaneously at any one time. Alternatively, you could have the function include a parameter for how many chunks to split the computation into. You could also have the fold adapt as it went, choosing whether or not to spark based on how many sparks exist at any given time:
parseq :: a -> b -> b
parseq a b =
    runST
        (bool (par a b) (seq a b) <$>
         unsafeIOToST (liftA2 (>) numSparks getNumCapabilities))
treeFoldAdaptive :: (a -> a -> a) -> a -> [a] -> a
treeFoldAdaptive f =
    Lazy.treeFold (\l r -> r `parseq` (l `parseq` f l r))
Adapted from this comment by Edward Kmett. This is actually the fastest version of all the folds.
All of this is provided in a library I’ve put up on Hackage.