Deriving a Linear-Time Applicative Traversal of a Rose Tree

Posted on May 28, 2019
Part 7 of a 10-part series on Breadth-First Traversals

The Story so Far

Currently, we have several different ways to enumerate a tree in breadth-first order. The typical solution (which is also the usual recommended approach in imperative programming) uses a queue, as described by Okasaki (2000). If we take the simplest possible queue (a list), we get a quadratic-time algorithm, albeit with a simple implementation. The next simplest version uses a banker’s queue (which is just a pair of lists). Starting from that version, if we inline the queue operations and apply identities like the following:

foldr f b . reverse = foldl (flip f) b

We’ll get to the following definition:

import Data.Tree (Tree (..), Forest)

bfe :: Forest a -> [a]
bfe ts = foldr f b ts []
  where
    f (Node x xs) fw bw = x : fw (xs : bw)

    b [] = []
    b qs = foldl (foldr f) b qs []
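To try the definition out, here is a self-contained version. It assumes a local Tree type that mirrors Data.Tree’s Node constructor (so the example needs no packages), and the sample tree is made up for illustration.

```haskell
-- Local stand-in for Data.Tree, so the example has no dependencies.
data Tree a = Node a [Tree a]

type Forest a = [Tree a]

-- Breadth-first enumeration via an inlined banker's queue.
bfe :: Forest a -> [a]
bfe ts = foldr f b ts []
  where
    -- Emit the root, continue along the current level (fw), and push this
    -- node's children onto the reversed list of next-level forests (bw).
    f (Node x xs) fw bw = x : fw (xs : bw)

    -- Current level exhausted: replay the saved forests in their original
    -- left-to-right order as the next level.
    b [] = []
    b qs = foldl (foldr f) b qs []

-- Hypothetical sample tree: 1 -> [2 -> [4], 3 -> [5 -> [6]]]
sample :: Tree Int
sample = Node 1 [Node 2 [Node 4 []], Node 3 [Node 5 [Node 6 []]]]

main :: IO ()
main = print (bfe [sample])
```

The `foldl` in `b` is where the `foldr f b . reverse = foldl (flip f) b` identity has been applied: the backward list is never reversed explicitly.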

We can get from this function to others (like one which uses a corecursive queue, and so on) through similar derivations. I might someday write a post on each derivation, starting from the simple version and showing how to reach the more efficient one at each step.

For today, though, I’m interested in traversing a rose tree, where traversal is meant in the applicative sense: running an effectful function of type a -> f b over every node, in breadth-first order, to get an f (Tree b).

Thus far, I’ve managed to write linear-time traversals, but they’ve been unsatisfying. They work by enumerating the tree, traversing the effectful function over the resulting list, and then rebuilding the tree. Since each of those steps takes linear time, the whole thing is indeed a linear-time traversal, but I hadn’t been able to fuse away the intermediate list.
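For contrast, here is a rough sketch of that enumerate, traverse, rebuild shape. It is a hypothetical reconstruction, not the code from the earlier posts: it enumerates into levels (using the linear-time levels function derived later in this post) rather than one flat list, because the level structure makes the rebuild step easy to write. The helpers fill and fillForest, the local Tree type, and the sample tree are all made up for illustration.

```haskell
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

-- Local stand-in for Data.Tree's rose tree.
data Tree a = Node a [Tree a] deriving (Show, Eq)

-- Enumerate: breadth-first levels in linear time.
levels :: Tree a -> [[a]]
levels t = go t []
  where
    go (Node x xs) (q : qs) = (x : q) : foldr go qs xs
    go (Node x xs) []       = [x]     : foldr go [] xs

-- Rebuild (hypothetical helper): relabel a tree with labels taken level by
-- level, left to right. Also returns the labels that were not consumed.
fill :: Tree a -> [[b]] -> (Tree b, [[b]])
fill (Node _ cs) ((y : ys) : lvls) = (Node y cs', ys : lvls')
  where
    (cs', lvls') = fillForest cs lvls
fill _ _ = error "fill: labels do not match the shape of the tree"

fillForest :: [Tree a] -> [[b]] -> ([Tree b], [[b]])
fillForest [] lvls = ([], lvls)
fillForest (t : ts) lvls = (t' : ts', lvls'')
  where
    (t', lvls')   = fill t lvls
    (ts', lvls'') = fillForest ts lvls'

-- Enumerate, traverse the intermediate structure, then rebuild.
bftViaLevels :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bftViaLevels f t = fmap (fst . fill t) (traverse (traverse f) (levels t))

sample :: Tree Int
sample = Node 1 [Node 2 [Node 4 []], Node 3 [Node 5 [Node 6 []]]]

main :: IO ()
main = do
  print (getConst (bftViaLevels (\x -> Const [x]) sample))
  print (runIdentity (bftViaLevels (Identity . (* 10)) sample))
```

The `[[b]]` produced by `traverse (traverse f)` is exactly the intermediate structure that the rest of the post tries to get rid of.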

Phases

The template for the algorithm I want comes from the Phases applicative (Easterly 2019):

{-# LANGUAGE GADTs #-}

data Phases f a where
  Lift   :: f a -> Phases f a
  (:<*>) :: f (a -> b) -> Phases f a -> Phases f b

We can use it to write a breadth-first traversal like so:

bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (traverse go xs))

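The key component that makes this work is that the Applicative instance combines effects in parallel: <*> zips its two arguments together phase by phase, so all the effects in one phase run before any effects in the next: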

instance Functor f => Functor (Phases f) where
    fmap f (Lift x) = Lift (fmap f x)
    fmap f (fs :<*> xs) = fmap (f.) fs :<*> xs
    
instance Applicative f => Applicative (Phases f) where
    pure = Lift . pure
    Lift fs      <*> Lift xs      = Lift (fs <*> xs)
    (fs :<*> gs) <*> Lift xs      = liftA2 flip fs xs :<*> gs
    Lift fs      <*> (xs :<*> ys) = liftA2 (.)  fs xs :<*> ys
    (fs :<*> gs) <*> (xs :<*> ys) = liftA2 c    fs xs :<*> liftA2 (,) gs ys
      where
        c f g ~(x,y) = f x (g y)

We’re also using the following helper functions:

runPhases :: Applicative f => Phases f a -> f a
runPhases (Lift x) = x
runPhases (fs :<*> xs) = fs <*> runPhases xs

later :: Applicative f => Phases f a -> Phases f a
later = (:<*>) (pure id)
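Assembled into one module, the pieces above run as-is. This sketch assumes a local Tree type in place of Data.Tree and a made-up sample tree; Const [a] collects labels in the order their effects run, and Identity checks that the tree is rebuilt. The hypothetical phaseOrder value shows the phase-by-phase combination directly: the effect pushed one phase later runs second even though it is written first.

```haskell
{-# LANGUAGE GADTs #-}
import Control.Applicative (liftA2)
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

-- Local stand-in for Data.Tree's rose tree.
data Tree a = Node a [Tree a] deriving (Show, Eq)

data Phases f a where
  Lift   :: f a -> Phases f a
  (:<*>) :: f (a -> b) -> Phases f a -> Phases f b

instance Functor f => Functor (Phases f) where
  fmap f (Lift x)     = Lift (fmap f x)
  fmap f (fs :<*> xs) = fmap (f .) fs :<*> xs

instance Applicative f => Applicative (Phases f) where
  pure = Lift . pure
  Lift fs      <*> Lift xs      = Lift (fs <*> xs)
  (fs :<*> gs) <*> Lift xs      = liftA2 flip fs xs :<*> gs
  Lift fs      <*> (xs :<*> ys) = liftA2 (.) fs xs :<*> ys
  (fs :<*> gs) <*> (xs :<*> ys) = liftA2 c fs xs :<*> liftA2 (,) gs ys
    where
      c f g ~(x, y) = f x (g y)

runPhases :: Applicative f => Phases f a -> f a
runPhases (Lift x)     = x
runPhases (fs :<*> xs) = fs <*> runPhases xs

later :: Applicative f => Phases f a -> Phases f a
later = (:<*>) (pure id)

-- The quadratic Phases-based traversal.
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (traverse go xs))

-- "b" is written first but sits one phase later, so "a" runs first.
phaseOrder :: String
phaseOrder = getConst (runPhases (liftA2 (,) delayed now))
  where
    delayed = later (Lift (Const "b")) :: Phases (Const String) ()
    now     = Lift (Const "a")         :: Phases (Const String) ()

example :: Tree Int
example = Node 1 [Node 2 [Node 3 []], Node 4 [Node 5 [Node 6 []]]]

main :: IO ()
main = do
  putStrLn phaseOrder
  print (getConst (bft (\x -> Const [x]) example))
  print (runIdentity (bft (Identity . (* 10)) example))
```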

The problem is that it’s quadratic: the traverse in:

go (Node x xs) = liftA2 Node (Lift (f x)) (later (traverse go xs))

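hides some expensive calls to <*>: each one zips the phases of its two arguments together level by level, and these zips are repeated at every node of the tree.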

A Roadmap for Optimisation

The problem with the Phases traversal is actually analogous to another function for enumeration: levels from Gibbons (2015).

levels :: Tree a -> [[a]]
levels (Node x xs) = [x] : foldr lzw [] (map levels xs)
  where
    lzw [] ys = ys
    lzw xs [] = xs
    lzw (x:xs) (y:ys) = (x ++ y) : lzw xs ys

lzw takes the place of <*> here, but the overall issue is the same: we’re zipping at every node, which makes the whole thing quadratic.

However, from the above function we can derive a linear-time enumeration. It looks like this:

levels :: Tree a -> [[a]]
levels ts = f ts []
  where
    f (Node x xs) (q:qs) = (x:q) : foldr f qs xs
    f (Node x xs) []     = [x]   : foldr f [] xs
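To check that the two versions agree, here they are side by side. The local Tree type, the names levelsZip and levelsAcc, and the sample trees are assumptions for this sketch; the sample has branches of different depths so that the zip in the quadratic version actually has uneven lists to merge.

```haskell
-- Local stand-in for Data.Tree's rose tree.
data Tree a = Node a [Tree a]

-- Quadratic version (after Gibbons): zips the children's levels at every node.
levelsZip :: Tree a -> [[a]]
levelsZip (Node x xs) = [x] : foldr lzw [] (map levelsZip xs)
  where
    lzw [] qs = qs
    lzw ps [] = ps
    lzw (p : ps) (q : qs) = (p ++ q) : lzw ps qs

-- Linear version: the levels of everything to the right are passed in as an
-- accumulator, so each node conses onto an existing level instead of zipping.
levelsAcc :: Tree a -> [[a]]
levelsAcc t = f t []
  where
    f (Node x xs) (q : qs) = (x : q) : foldr f qs xs
    f (Node x xs) []       = [x]     : foldr f [] xs

-- Hypothetical sample tree with unevenly deep branches.
sample :: Tree Int
sample = Node 1 [Node 2 [Node 4 []], Node 3 [Node 5 [Node 6 []]]]

main :: IO ()
main = do
  print (levelsZip sample)
  print (levelsAcc sample)
```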

Our objective is clear, then: derive the linear-time implementation of bft from the quadratic one, in a way analogous to the two levels functions above. Once the target is clear, this is relatively straightforward: the rest of this post is devoted to the derivation.

Derivation

We start with the original bft.

bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (traverse go xs))

Inline traverse.

bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (go' xs))
    go' = foldr (liftA2 (:) . go) (pure [])
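The inlining relies on traverse for lists being a right fold of liftA2 (:). A quick spot check against the Prelude’s traverse, using a made-up effectful function in Maybe and the list applicative (traverseList and halveNonNeg are hypothetical names):

```haskell
import Control.Applicative (liftA2)

-- traverse on lists, written as the fold used in the "Inline traverse" step.
traverseList :: Applicative f => (a -> f b) -> [a] -> f [b]
traverseList g = foldr (liftA2 (:) . g) (pure [])

-- Hypothetical effectful function: fails on negative inputs.
halveNonNeg :: Int -> Maybe Int
halveNonNeg n
  | n < 0     = Nothing
  | otherwise = Just (n `div` 2)

main :: IO ()
main = do
  print (traverseList halveNonNeg [4, 6, 8])
  print (traverseList halveNonNeg [4, -1, 8])
  print (traverseList (\x -> [x, x + 1]) [0, 10 :: Int])
```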

Factor out go''.

bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (go' xs))
    go' = foldr go'' (pure [])
    go'' (Node x xs) ys = liftA2 (:) (liftA2 Node (Lift (f x)) (later (go' xs))) ys

Inline go' (and rename go'' to go').

bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))
    go' (Node x xs) ys = liftA2 (:) (liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))) ys

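Unfold the inner liftA2 in go', using the definitions of liftA2, later and <*> (together with liftA2 (.) fs (pure id) = fs).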

bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))
    go' (Node x xs) ys = liftA2 (:) (fmap Node (f x) :<*> (foldr go' (pure []) xs)) ys

Definition of liftA2 again, this time for the outer liftA2 (:) (pattern-matching on ys).

bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))
    go' (Node x xs) (Lift ys)    = fmap (((:).) . Node) (f x) :<*> (foldr go' (pure []) xs) <*> Lift ys
    go' (Node x xs) (ys :<*> zs) = fmap (((:).) . Node) (f x) :<*> (foldr go' (pure []) xs) <*> ys :<*> zs

Definition of <*>.

bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))
    go' (Node x xs) (Lift ys)    = liftA2 flip (fmap (((:).) . Node) (f x)) ys :<*> foldr go' (pure []) xs
    go' (Node x xs) (ys :<*> zs) = liftA2 c (fmap (((:).) . Node) (f x)) ys :<*> liftA2 (,) (foldr go' (pure []) xs) zs
      where
        c f g ~(x,y) = f x (g y)

Fuse liftA2 with fmap, using liftA2 g (fmap h x) y = liftA2 (g . h) x y.

bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))
    go' (Node x xs) (Lift ys)    = liftA2 (flip . (((:).) . Node)) (f x) ys :<*> foldr go' (pure []) xs
    go' (Node x xs) (ys :<*> zs) = liftA2 (c . (((:).) . Node)) (f x) ys :<*> liftA2 (,) (foldr go' (pure []) xs) zs
      where
        c f g ~(x,y) = f x (g y)
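
Beta-reduce the functions passed to liftA2, and push runPhases into go (using runPhases (fs :<*> p) = fs <*> runPhases p), so that go now works directly in f.
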
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = go
  where
    go (Node x xs) = liftA2 Node (f x) (runPhases (foldr go' (pure []) xs))
    
    go' (Node x xs) (Lift ys)    = liftA2 (\y zs ys -> Node y ys : zs) (f x) ys :<*> foldr go' (pure []) xs
    go' (Node x xs) (ys :<*> zs) = liftA2 c (f x) ys :<*> liftA2 (,) (foldr go' (pure []) xs) zs
      where
        c y g ~(ys,z) = Node y ys : g z
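The fusion and beta-reduction steps are easy to get wrong by hand, so here is a small spot check. It uses a local Tree stand-in and hypothetical names (liftCaseFused, liftCaseReduced, zipCaseFused, zipCaseReduced) for the functions handed to liftA2 in the two go' equations; agreeing on sample inputs is evidence, not a proof.

```haskell
-- Local stand-in for Data.Tree's rose tree.
data Tree a = Node a [Tree a] deriving (Show, Eq)

-- The combining function from the Phases Applicative instance.
c :: (a -> b -> r) -> (d -> b) -> (a, d) -> r
c f g ~(x, y) = f x (g y)

-- Lift case: the function passed to liftA2 after fusion, and its beta-reduced form.
liftCaseFused, liftCaseReduced :: b -> [Tree b] -> [Tree b] -> [Tree b]
liftCaseFused   = flip . (((:) .) . Node)
liftCaseReduced = \y zs ys -> Node y ys : zs

-- (:<*>) case: the same comparison for the function built with c.
zipCaseFused, zipCaseReduced :: b -> (z -> [Tree b]) -> ([Tree b], z) -> [Tree b]
zipCaseFused = c . (((:) .) . Node)
zipCaseReduced y g ~(ys, z) = Node y ys : g z

main :: IO ()
main = do
  let zs = [Node 'z' []]
      ys = [Node 'b' []]
      g n = replicate n (Node 'z' [])
  print (liftCaseFused 'a' zs ys == liftCaseReduced 'a' zs ys)
  print (zipCaseFused 'a' g (ys, 2 :: Int) == zipCaseReduced 'a' g (ys, 2))
```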

At this point, we actually hit a wall: the expression

liftA2 (,) (foldr go' (pure []) xs) zs

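is what makes the whole thing quadratic: for each node, the phases of its children are built separately and then zipped with zs, the phases of the siblings to its right. We need to thread that liftA2 along with the fold to make it linear. This is the only real trick in the derivation: I’ll use polymorphic recursion to avoid the extra zip. Instead of pairing the children’s result with zs afterwards, we pass zs itself (tagged with an empty list of trees) as the starting value of the fold, so the children are added directly onto the existing phases. Because the type of the extra component carried alongside the trees changes from one level to the next, go' needs a polymorphic type signature.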

{-# LANGUAGE ScopedTypeVariables #-}

import Data.Bifunctor (first)

bft :: forall f a b. Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = go
  where
    go (Node x xs) = liftA2 (\y (ys,_) -> Node y ys) (f x) (runPhases (foldr go' (pure ([],())) xs))
    
    go' :: forall c. Tree a -> Phases f ([Tree b], c) -> Phases f ([Tree b], c)
    go' (Node x xs) ys@(Lift _)  = fmap (\y -> first (pure . Node y)) (f x) :<*> foldr go' ys xs
    go' (Node x xs) (ys :<*> zs) = liftA2 c (f x) ys :<*> foldr go' (fmap ((,) []) zs) xs
      where
        c y g ~(ys,z) = first (Node y ys:) (g z)

And that’s it!
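Putting the pieces together, here is the final version as a standalone module. It is a sketch under a few assumptions: a local Tree type stands in for Data.Tree, first comes from Data.Bifunctor, the local helper c is renamed k so it is not confused with the type variable c, and the sample tree is made up. Const [a] collects the labels in the order their effects run, and Identity checks that the tree is rebuilt correctly.

```haskell
{-# LANGUAGE GADTs, ScopedTypeVariables #-}
import Control.Applicative (liftA2)
import Data.Bifunctor (first)
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

-- Local stand-in for Data.Tree's rose tree.
data Tree a = Node a [Tree a] deriving (Show, Eq)

data Phases f a where
  Lift   :: f a -> Phases f a
  (:<*>) :: f (a -> b) -> Phases f a -> Phases f b

instance Functor f => Functor (Phases f) where
  fmap f (Lift x)     = Lift (fmap f x)
  fmap f (fs :<*> xs) = fmap (f .) fs :<*> xs

instance Applicative f => Applicative (Phases f) where
  pure = Lift . pure
  Lift fs      <*> Lift xs      = Lift (fs <*> xs)
  (fs :<*> gs) <*> Lift xs      = liftA2 flip fs xs :<*> gs
  Lift fs      <*> (xs :<*> ys) = liftA2 (.) fs xs :<*> ys
  (fs :<*> gs) <*> (xs :<*> ys) = liftA2 c fs xs :<*> liftA2 (,) gs ys
    where
      c f g ~(x, y) = f x (g y)

runPhases :: Applicative f => Phases f a -> f a
runPhases (Lift x)     = x
runPhases (fs :<*> xs) = fs <*> runPhases xs

-- The final, linear-time traversal from the derivation.
bft :: forall f a b. Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = go
  where
    go (Node x xs) =
      liftA2 (\y (ys, _) -> Node y ys) (f x) (runPhases (foldr go' (pure ([], ())) xs))

    -- Polymorphic recursion: the children are folded in at a different
    -- instantiation of c than the one go' was called at.
    go' :: forall c. Tree a -> Phases f ([Tree b], c) -> Phases f ([Tree b], c)
    go' (Node x xs) ys@(Lift _)  = fmap (\y -> first (pure . Node y)) (f x) :<*> foldr go' ys xs
    go' (Node x xs) (ys :<*> zs) = liftA2 k (f x) ys :<*> foldr go' (fmap ((,) []) zs) xs
      where
        k y g ~(ts, z) = first (Node y ts :) (g z)

example :: Tree Int
example = Node 1 [Node 2 [Node 3 []], Node 4 [Node 5 [Node 6 []]]]

main :: IO ()
main = do
  print (getConst (bft (\x -> Const [x]) example))
  print (runIdentity (bft (Identity . (* 10)) example))
```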

Avoiding Maps

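Finally, we can write a slightly different version that avoids some unnecessary fmaps: instead of Phases, it uses a type, Levels, whose Later constructor stores a liftA2-style combining function rather than being built around <*>.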

data Levels f a where
  Now   :: a -> Levels f a
  Later :: (a -> b -> c) -> f a -> Levels f b -> Levels f c

instance Functor f => Functor (Levels f) where
    fmap f (Now x) = Now (f x)
    fmap f (Later c xs ys) = Later ((f.) . c) xs ys
            
runLevels :: Applicative f => Levels f a -> f a
runLevels (Now x) = pure x
runLevels (Later f xs ys) = liftA2 f xs (runLevels ys)

bft :: forall f a b. Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = go
  where
    go (Node x xs) = liftA2 (\y (ys,_) -> Node y ys) (f x) (runLevels (foldr go' (Now ([],())) xs))
    
    go' :: forall c. Tree a -> Levels f ([Tree b], c) -> Levels f ([Tree b], c)
    go' (Node x xs) ys@(Now _)      = Later (\y -> first (pure . Node y)) (f x) (foldr go' ys xs)
    go' (Node x xs) (Later k ys zs) = Later id (liftA2 c (f x) ys) (foldr go' (fmap ((,) []) zs) xs)
      where
        c y g ~(ys,z) = first (Node y ys:) (k g z)
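The Levels variant can be exercised the same way. This self-contained sketch repeats the definitions with the same assumptions as before (a local Tree stand-in, first from Data.Bifunctor, a made-up sample tree) and renames the local helper c to h to keep it apart from the type variable c.

```haskell
{-# LANGUAGE GADTs, ScopedTypeVariables #-}
import Control.Applicative (liftA2)
import Data.Bifunctor (first)
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

-- Local stand-in for Data.Tree's rose tree.
data Tree a = Node a [Tree a] deriving (Show, Eq)

data Levels f a where
  Now   :: a -> Levels f a
  Later :: (a -> b -> c) -> f a -> Levels f b -> Levels f c

instance Functor f => Functor (Levels f) where
  fmap f (Now x)         = Now (f x)
  fmap f (Later c xs ys) = Later ((f .) . c) xs ys

-- Run the effects one level at a time, shallowest first.
runLevels :: Applicative f => Levels f a -> f a
runLevels (Now x)         = pure x
runLevels (Later f xs ys) = liftA2 f xs (runLevels ys)

bft :: forall f a b. Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = go
  where
    go (Node x xs) =
      liftA2 (\y (ys, _) -> Node y ys) (f x) (runLevels (foldr go' (Now ([], ())) xs))

    go' :: forall c. Tree a -> Levels f ([Tree b], c) -> Levels f ([Tree b], c)
    go' (Node x xs) ys@(Now _)      = Later (\y -> first (pure . Node y)) (f x) (foldr go' ys xs)
    go' (Node x xs) (Later k ys zs) = Later id (liftA2 h (f x) ys) (foldr go' (fmap ((,) []) zs) xs)
      where
        h y g ~(ts, z) = first (Node y ts :) (k g z)

example :: Tree Int
example = Node 1 [Node 2 [Node 3 []], Node 4 [Node 5 [Node 6 []]]]

main :: IO ()
main = do
  print (getConst (bft (\x -> Const [x]) example))
  print (runIdentity (bft (Identity . (* 10)) example))
```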

References

Easterly, Noah. 2019. “Functions and newtype wrappers for traversing Trees: Rampion/tree-traversals.” https://github.com/rampion/tree-traversals.
Gibbons, Jeremy. 2015. “Breadth-First Traversal.” Patterns in Functional Programming. https://patternsinfp.wordpress.com/2015/03/05/breadth-first-traversal/.
Okasaki, Chris. 2000. “Breadth-first Numbering: Lessons from a Small Exercise in Algorithm Design.” In Proceedings of the Fifth ACM SIGPLAN International Conference on Functional Programming, 131–136. ICFP ’00. New York, NY, USA: ACM. doi:10.1145/351240.351253. https://www.cs.tufts.edu/~nr/cs257/archive/chris-okasaki/breadth-first.pdf.