Deriving a Linear-Time Applicative Traversal of a Rose Tree
The Story so Far
Currently, we have several different ways to enumerate a tree in breadth-first order. The typical solution (which is also the usual recommendation in imperative programming) uses a queue, as described by Okasaki (2000). If we take the simplest possible queue (a list), we get a quadratic-time algorithm, albeit one with a simple implementation. The next simplest version uses a banker’s queue (which is just a pair of lists). From this version, if we inline and apply identities like the following:
foldr f b . reverse = foldl (flip f) b
We’ll get to the following definition:
bfe :: Forest a -> [a]
bfe ts = foldr f b ts []
  where
    f (Node x xs) fw bw = x : fw (xs : bw)

    b [] = []
    b qs = foldl (foldr f) b qs []
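For comparison, here is a rough sketch of the simple list-as-queue enumeration next to the definition above. It is an illustration under assumptions: the local Tree and Forest types stand in for the ones in Data.Tree, and bfeNaive and example are hypothetical names that do not come from the original derivation.

```haskell
-- Local stand-ins for Data.Tree's Tree and Forest (assumption).
data Tree a = Node a [Tree a] deriving (Eq, Show)
type Forest a = [Tree a]

-- Simplest possible queue: a plain list. Appending the children
-- with (++) at every step is what makes this version quadratic.
bfeNaive :: Forest a -> [a]
bfeNaive [] = []
bfeNaive (Node x xs : q) = x : bfeNaive (q ++ xs)

-- The definition from the text, with the queue inlined into folds.
bfe :: Forest a -> [a]
bfe ts = foldr f b ts []
  where
    f (Node x xs) fw bw = x : fw (xs : bw)

    b [] = []
    b qs = foldl (foldr f) b qs []

-- Hypothetical example tree: depth-first order would be 1,2,4,3.
example :: Tree Int
example = Node 1 [Node 2 [Node 4 []], Node 3 []]
```

On this example both functions should produce the same breadth-first ordering, which differs from the depth-first one.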
We can get from this function to others (like one which uses a corecursive queue, and so on) through a similar derivation. I might some day write a post on each derivation, starting from the simple version and demonstrating how to get to a more efficient version at each step.
For today, though, I’m interested in the traversal of a rose tree. Traversal, here, of course, is in the applicative sense.
Thus far, I’ve managed to write linear-time traversals, but they’ve been unsatisfying. They work by enumerating the tree, traversing the effectful function over the list, and then rebuilding the tree. Since each of those steps only takes linear time, the whole thing is indeed a linear-time traversal, but I hadn’t been able to fuse away the intermediate step.
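The three-step approach can be sketched roughly as follows. This is an illustration rather than the code from the earlier attempts: the local Tree type stands in for Data.Tree, and fill, bftUnfused, and example are hypothetical names. It enumerates the tree level by level, traverses the nested list, and then refills the original shape.

```haskell
import Data.List (mapAccumL)

-- Local stand-in for Data.Tree's Tree (assumption).
data Tree a = Node a [Tree a] deriving (Eq, Show)

-- Step 1: enumerate the labels level by level.
levels :: Tree a -> [[a]]
levels t = f t []
  where
    f (Node x xs) (q:qs) = (x:q) : foldr f qs xs
    f (Node x xs) []     = [x]   : foldr f [] xs

-- Step 3: refill the tree's shape from per-level lists of new labels,
-- returning the relabelled tree and the unused labels.
fill :: Tree a -> [[b]] -> (Tree b, [[b]])
fill (Node _ ts) ((y:ys):yss) = (Node y ts', ys : yss')
  where
    (yss', ts') = mapAccumL step yss ts
    step q t = let (t', q') = fill t q in (q', t')
fill _ _ = error "fill: labels do not match the tree's shape"

-- Step 2 sits in the middle: traversing the nested list runs the
-- effects in breadth-first order. The intermediate list is not fused away.
bftUnfused :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bftUnfused f t = fmap (fst . fill t) (traverse (traverse f) (levels t))

-- Hypothetical example tree: depth-first order would be 1,2,4,3.
example :: Tree Int
example = Node 1 [Node 2 [Node 4 []], Node 3 []]
```

Using the pair applicative as a log, bftUnfused (\x -> ([x], x * 10)) example should record the labels in breadth-first order while keeping the tree's shape.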
Phases
The template for the algorithm I want comes from the Phases applicative (Easterly 2019):
data Phases f a where
  Lift   :: f a -> Phases f a
  (:<*>) :: f (a -> b) -> Phases f a -> Phases f b
We can use it to write a breadth-first traversal like so:
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (traverse go xs))
The key component that makes this work is that it combines applicative effects in parallel:
instance Functor f => Functor (Phases f) where
  fmap f (Lift x) = Lift (fmap f x)
  fmap f (fs :<*> xs) = fmap (f.) fs :<*> xs

instance Applicative f => Applicative (Phases f) where
  pure = Lift . pure
  Lift fs <*> Lift xs = Lift (fs <*> xs)
  (fs :<*> gs) <*> Lift xs = liftA2 flip fs xs :<*> gs
  Lift fs <*> (xs :<*> ys) = liftA2 (.) fs xs :<*> ys
  (fs :<*> gs) <*> (xs :<*> ys) = liftA2 c fs xs :<*> liftA2 (,) gs ys
    where
      c f g ~(x,y) = f x (g y)
We’re also using the following helper functions:
runPhases :: Applicative f => Phases f a -> f a
runPhases (Lift x) = x
runPhases (fs :<*> xs) = fs <*> runPhases xs

later :: Applicative f => Phases f a -> Phases f a
later = (:<*>) (pure id)
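To see the pieces working together, here is a sketch that assembles the definitions above into a single module. The local Tree type (a stand-in for Data.Tree), the example tree, and the choice of the pair applicative as a log are my own assumptions for illustration.

```haskell
{-# LANGUAGE GADTs #-}

import Control.Applicative (liftA2)

-- Local stand-in for Data.Tree's Tree (assumption).
data Tree a = Node a [Tree a] deriving (Eq, Show)

data Phases f a where
  Lift   :: f a -> Phases f a
  (:<*>) :: f (a -> b) -> Phases f a -> Phases f b

instance Functor f => Functor (Phases f) where
  fmap f (Lift x) = Lift (fmap f x)
  fmap f (fs :<*> xs) = fmap (f.) fs :<*> xs

-- Effects in the same phase are combined; later phases are zipped.
instance Applicative f => Applicative (Phases f) where
  pure = Lift . pure
  Lift fs <*> Lift xs = Lift (fs <*> xs)
  (fs :<*> gs) <*> Lift xs = liftA2 flip fs xs :<*> gs
  Lift fs <*> (xs :<*> ys) = liftA2 (.) fs xs :<*> ys
  (fs :<*> gs) <*> (xs :<*> ys) = liftA2 c fs xs :<*> liftA2 (,) gs ys
    where
      c f g ~(x,y) = f x (g y)

runPhases :: Applicative f => Phases f a -> f a
runPhases (Lift x) = x
runPhases (fs :<*> xs) = fs <*> runPhases xs

-- Delay every effect by one phase.
later :: Applicative f => Phases f a -> Phases f a
later = (:<*>) (pure id)

bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (traverse go xs))

-- Hypothetical example tree: depth-first order would be 1,2,4,3.
example :: Tree Int
example = Node 1 [Node 2 [Node 4 []], Node 3 []]
```

Running bft (\x -> ([x], x * 10)) example should log the labels level by level, since each level of the tree is assigned to its own phase.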
The problem is that it’s quadratic: the traverse in:
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (traverse go xs))
hides some expensive calls to <*>.
A Roadmap for Optimisation
The problem with the Phases traversal is actually analogous to another function for enumeration: levels from Gibbons (2015).
levels :: Tree a -> [[a]]
levels (Node x xs) = [x] : foldr lzw [] (map levels xs)
  where
    lzw [] ys = ys
    lzw xs [] = xs
    lzw (x:xs) (y:ys) = (x ++ y) : lzw xs ys
lzw takes the place of <*> here, but the overall issue is the same: we’re zipping at every point, making the whole thing quadratic.
However, from the above function we can derive a linear-time enumeration. It looks like this:
levels :: Tree a -> [[a]]
levels ts = f ts []
  where
    f (Node x xs) (q:qs) = (x:q) : foldr f qs xs
    f (Node x xs) [] = [x] : foldr f [] xs
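As a quick sanity check, the two versions can be placed side by side. This is only a sketch: levelsQuadratic, levelsLinear, and example are hypothetical names chosen so that both definitions can live in one module, and the local Tree type stands in for Data.Tree.

```haskell
-- Local stand-in for Data.Tree's Tree (assumption).
data Tree a = Node a [Tree a] deriving (Eq, Show)

-- The zipping version: every node zips the levels of its children.
levelsQuadratic :: Tree a -> [[a]]
levelsQuadratic (Node x xs) = [x] : foldr lzw [] (map levelsQuadratic xs)
  where
    lzw [] ys = ys
    lzw xs [] = xs
    lzw (x':xs') (y:ys) = (x' ++ y) : lzw xs' ys

-- The accumulating version: the levels built so far are threaded
-- through the fold, so each node conses onto the front of its level.
levelsLinear :: Tree a -> [[a]]
levelsLinear ts = f ts []
  where
    f (Node x xs) (q:qs) = (x:q) : foldr f qs xs
    f (Node x xs) []     = [x]   : foldr f [] xs

-- Hypothetical example tree.
example :: Tree Int
example = Node 1 [Node 2 [Node 4 []], Node 3 []]
```

Both functions should agree on the example, giving one list per level of the tree.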
Our objective is clear, then: try to derive the linear-time implementation of bft from the quadratic, in a way analogous to the above two functions. This is actually relatively straightforward once the target is clear: the rest of this post is devoted to the derivation.
Derivation
First, we start off with the original bft.
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (traverse go xs))
Inline traverse.
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (go' xs))
    go' = foldr (liftA2 (:) . go) (pure [])
Factor out go''.
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (go' xs))
    go' = foldr go'' (pure [])
    go'' (Node x xs) ys = liftA2 (:) (liftA2 Node (Lift (f x)) (later (go' xs))) ys
Inline go' (and rename go'' to go').
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))
    go' (Node x xs) ys = liftA2 (:) (liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))) ys
Definition of liftA2.
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))
    go' (Node x xs) ys = liftA2 (:) (fmap Node (f x) :<*> (foldr go' (pure []) xs)) ys
Definition of liftA2 (pattern-matching on ys).
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))
    go' (Node x xs) (Lift ys) = fmap (((:).) . Node) (f x) :<*> (foldr go' (pure []) xs) <*> Lift ys
    go' (Node x xs) (ys :<*> zs) = fmap (((:).) . Node) (f x) :<*> (foldr go' (pure []) xs) <*> ys :<*> zs
Definition of <*>.
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))
    go' (Node x xs) (Lift ys) = liftA2 flip (fmap (((:).) . Node) (f x)) ys :<*> foldr go' (pure []) xs
    go' (Node x xs) (ys :<*> zs) = liftA2 c (fmap (((:).) . Node) (f x)) ys :<*> liftA2 (,) (foldr go' (pure []) xs) zs
      where
        c f g ~(x,y) = f x (g y)
Fuse liftA2 with fmap.
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = runPhases . go
  where
    go (Node x xs) = liftA2 Node (Lift (f x)) (later (foldr go' (pure []) xs))
    go' (Node x xs) (Lift ys) = liftA2 (flip . (((:).) . Node)) (f x) ys :<*> foldr go' (pure []) xs
    go' (Node x xs) (ys :<*> zs) = liftA2 (c . (((:).) . Node)) (f x) ys :<*> liftA2 (,) (foldr go' (pure []) xs) zs
      where
        c f g ~(x,y) = f x (g y)
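The fusion step relies on the fact that, for a lawful applicative, mapping over the first argument of liftA2 is the same as composing the function beforehand. A small sketch of the two sides, where unfused and fused are hypothetical names used only for this illustration:

```haskell
import Control.Applicative (liftA2)

-- Before the step: fmap over the first argument, then liftA2.
unfused :: Applicative f => (b -> c -> d) -> (a -> b) -> f a -> f c -> f d
unfused g h x y = liftA2 g (fmap h x) y

-- After the step: compose the functions, then a single liftA2.
fused :: Applicative f => (b -> c -> d) -> (a -> b) -> f a -> f c -> f d
fused g h x y = liftA2 (g . h) x y
```

For instance, unfused (+) (*2) (Just 3) (Just 4) and fused (+) (*2) (Just 3) (Just 4) should give the same result; this checks one instance of the law rather than proving it.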
Beta-reduction.
bft :: Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = go
  where
    go (Node x xs) = liftA2 Node (f x) (runPhases (foldr go' (pure []) xs))

    go' (Node x xs) (Lift ys) = liftA2 (\y zs ys -> Node y ys : zs) (f x) ys :<*> foldr go' (pure []) xs
    go' (Node x xs) (ys :<*> zs) = liftA2 c (f x) ys :<*> liftA2 (,) (foldr go' (pure []) xs) zs
      where
        c y g ~(ys,z) = Node y ys : g z
At this point, we actually hit a wall: the expression
    liftA2 (,) (foldr go' (pure []) xs) zs
is what makes the whole thing quadratic. We need to find a way to thread that liftA2 along with the fold to get it to linear. This is the only real trick in the derivation: I’ll use polymorphic recursion to avoid the extra zip.
bft :: forall f a b. Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = go
  where
    go (Node x xs) = liftA2 (\y (ys,_) -> Node y ys) (f x) (runPhases (foldr go' (pure ([],())) xs))

    go' :: forall c. Tree a -> Phases f ([Tree b], c) -> Phases f ([Tree b], c)
    go' (Node x xs) ys@(Lift _) = fmap (\y -> first (pure . Node y)) (f x) :<*> foldr go' ys xs
    go' (Node x xs) (ys :<*> zs) = liftA2 c (f x) ys :<*> foldr go' (fmap ((,) []) zs) xs
      where
        c y g ~(ys,z) = first (Node y ys:) (g z)
And that’s it!
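For anyone who wants to load the result, here is a sketch of the final definition as a standalone module. A few things are assumptions on my part: the local Tree type stands in for Data.Tree, first is taken from Data.Bifunctor, the ScopedTypeVariables extension is enabled so that the signature on go' can mention f, a, and b, and example is a hypothetical test tree.

```haskell
{-# LANGUAGE GADTs, ScopedTypeVariables #-}

import Control.Applicative (liftA2)
import Data.Bifunctor (first)

-- Local stand-in for Data.Tree's Tree (assumption).
data Tree a = Node a [Tree a] deriving (Eq, Show)

data Phases f a where
  Lift   :: f a -> Phases f a
  (:<*>) :: f (a -> b) -> Phases f a -> Phases f b

instance Functor f => Functor (Phases f) where
  fmap f (Lift x) = Lift (fmap f x)
  fmap f (fs :<*> xs) = fmap (f.) fs :<*> xs

instance Applicative f => Applicative (Phases f) where
  pure = Lift . pure
  Lift fs <*> Lift xs = Lift (fs <*> xs)
  (fs :<*> gs) <*> Lift xs = liftA2 flip fs xs :<*> gs
  Lift fs <*> (xs :<*> ys) = liftA2 (.) fs xs :<*> ys
  (fs :<*> gs) <*> (xs :<*> ys) = liftA2 c fs xs :<*> liftA2 (,) gs ys
    where
      c f g ~(x,y) = f x (g y)

runPhases :: Applicative f => Phases f a -> f a
runPhases (Lift x) = x
runPhases (fs :<*> xs) = fs <*> runPhases xs

bft :: forall f a b. Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = go
  where
    go (Node x xs) = liftA2 (\y (ys,_) -> Node y ys) (f x) (runPhases (foldr go' (pure ([],())) xs))

    -- The second tuple component changes type at each level, which is
    -- why go' needs a signature: it is called at a different c.
    go' :: forall c. Tree a -> Phases f ([Tree b], c) -> Phases f ([Tree b], c)
    go' (Node x xs) ys@(Lift _) = fmap (\y -> first (pure . Node y)) (f x) :<*> foldr go' ys xs
    go' (Node x xs) (ys :<*> zs) = liftA2 c (f x) ys :<*> foldr go' (fmap ((,) []) zs) xs
      where
        c y g ~(ys,z) = first (Node y ys:) (g z)

-- Hypothetical example tree: depth-first order would be 1,2,4,3.
example :: Tree Int
example = Node 1 [Node 2 [Node 4 []], Node 3 []]
```

With the pair applicative as a log, bft (\x -> ([x], x * 10)) example should visit the labels in breadth-first order and rebuild the same shape.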
Avoiding Maps
We can finally write a slightly different version that avoids some unnecessary fmaps by basing Phases on liftA2 rather than <*>.
data Levels f a where
  Now   :: a -> Levels f a
  Later :: (a -> b -> c) -> f a -> Levels f b -> Levels f c
instance Functor f => Functor (Levels f) where
  fmap f (Now x) = Now (f x)
  fmap f (Later c xs ys) = Later ((f.) . c) xs ys

runLevels :: Applicative f => Levels f a -> f a
runLevels (Now x) = pure x
runLevels (Later f xs ys) = liftA2 f xs (runLevels ys)
bft :: forall f a b. Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = go
  where
    go (Node x xs) = liftA2 (\y (ys,_) -> Node y ys) (f x) (runLevels (foldr go' (Now ([],())) xs))

    go' :: forall c. Tree a -> Levels f ([Tree b], c) -> Levels f ([Tree b], c)
    go' (Node x xs) ys@(Now _) = Later (\y -> first (pure . Node y)) (f x) (foldr go' ys xs)
    go' (Node x xs) (Later k ys zs) = Later id (liftA2 c (f x) ys) (foldr go' (fmap ((,) []) zs) xs)
      where
        c y g ~(ys,z) = first (Node y ys:) (k g z)
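The Levels version can be packaged the same way. As before, this is a sketch under assumptions: the local Tree type stands in for Data.Tree, first comes from Data.Bifunctor, ScopedTypeVariables is enabled for the signature on go', and example is a hypothetical test tree.

```haskell
{-# LANGUAGE GADTs, ScopedTypeVariables #-}

import Control.Applicative (liftA2)
import Data.Bifunctor (first)

-- Local stand-in for Data.Tree's Tree (assumption).
data Tree a = Node a [Tree a] deriving (Eq, Show)

data Levels f a where
  Now   :: a -> Levels f a
  Later :: (a -> b -> c) -> f a -> Levels f b -> Levels f c

-- Mapping only touches the combining function, never the effects.
instance Functor f => Functor (Levels f) where
  fmap f (Now x) = Now (f x)
  fmap f (Later c xs ys) = Later ((f.) . c) xs ys

runLevels :: Applicative f => Levels f a -> f a
runLevels (Now x) = pure x
runLevels (Later f xs ys) = liftA2 f xs (runLevels ys)

bft :: forall f a b. Applicative f => (a -> f b) -> Tree a -> f (Tree b)
bft f = go
  where
    go (Node x xs) = liftA2 (\y (ys,_) -> Node y ys) (f x) (runLevels (foldr go' (Now ([],())) xs))

    go' :: forall c. Tree a -> Levels f ([Tree b], c) -> Levels f ([Tree b], c)
    go' (Node x xs) ys@(Now _) = Later (\y -> first (pure . Node y)) (f x) (foldr go' ys xs)
    go' (Node x xs) (Later k ys zs) = Later id (liftA2 c (f x) ys) (foldr go' (fmap ((,) []) zs) xs)
      where
        c y g ~(ys,z) = first (Node y ys:) (k g z)

-- Hypothetical example tree: depth-first order would be 1,2,4,3.
example :: Tree Int
example = Node 1 [Node 2 [Node 4 []], Node 3 []]
```

On the example, bft (\x -> ([x], x * 10)) example should produce the same log and tree as the Phases-based version.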