Countdown
There’s a popular UK TV show called Countdown with a round where contestants have to get as close to some target number as possible by constructing an arithmetic expression from six random numbers.
You don’t have to use all of the numbers, and you’re allowed to use four operations: addition, subtraction, multiplication, and division. Additionally, each stage of the calculation must result in a positive integer.
Here’s an example: try to get to the target 586 using the numbers 100, 25, 1, 5, 3, and 10.
On the show, contestants get 30 seconds to think of an answer.
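The rules are easy to state as code. Here is a minimal sketch of an evaluator that enforces them. The `Operator` and `Calc` types and the `step` and `evalChecked` functions are my own hypothetical names, not the representation used later in the post.

```haskell
-- The four operations allowed on the show (hypothetical names).
data Operator = Plus | Minus | Times | Divide

-- A candidate answer: one of the starting numbers, or an operation
-- applied to two smaller candidates.
data Calc = Number Int | Apply Operator Calc Calc

-- Apply one operation to two positive operands, failing unless the
-- result is again a positive integer.
step :: Operator -> Int -> Int -> Maybe Int
step Plus x y = Just (x + y)
step Minus x y
  | x > y = Just (x - y)
  | otherwise = Nothing
step Times x y = Just (x * y)
step Divide x y
  | x `mod` y == 0 = Just (x `div` y)
  | otherwise = Nothing

-- Evaluate bottom-up; a single illegal step invalidates the whole
-- expression.
evalChecked :: Calc -> Maybe Int
evalChecked (Number n)
  | n > 0 = Just n
  | otherwise = Nothing
evalChecked (Apply op l r) = do
  x <- evalChecked l
  y <- evalChecked r
  step op x y
```

For the example above, the expression 25 * 3 + 1 + (100 * 5 + 10) evaluates to `Just 586`, while anything containing 3 - 5 or 5 / 3 evaluates to `Nothing`.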
Solution
Solving it in Haskell was first explored in depth in Hutton (2002). There, a basic “generate-and-test” implementation was provided and proven correct.
As an optimization problem, there are several factors which will influence the choice of algorithm:
- There’s no obvious heuristic for constructing subexpressions in order to get to a final result. In other words, given two candidate subexpressions, there’s no easy way to tell which is “closer” to the target: one can be closer numerically while the other is the one the solution actually ends up using.
- Because certain subexpressions aren’t allowed, we’ll be able to prune the search space as we go.
- Ideally, we’d only want to calculate each possible subexpression once, making it a pretty standard dynamic programming problem.
I’ll be focusing on the third point in this post, but we can add the second point in at the end. First, however, let’s write a naive implementation.
Generating all Expressions
I can’t think of a simpler way to solve the problem than generate-and-test, so we’ll work from there. Testing is easy ((target ==) . eval), so we’ll focus on generation. The core function we’ll use for this is usually called “unmerges”:
unmerges :: [a] -> [([a],[a])]
unmerges [x,y] = [([x],[y])]
unmerges (x:xs) =
  ([x],xs) :
  concat
    [ [(x:ys,zs),(ys,x:zs)]
    | (ys,zs) <- unmerges xs ]
unmerges _ = []
It generates all possible 2-partitions of a list, ignoring order:
>>> unmerges "abc"
[("a","bc"),("ab","c"),("b","ac")]
I haven’t looked much into how to optimize this function or make it nicer, as we’ll be swapping it out later.
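A quick sanity check (my own addition, not from the post): a list of n ≥ 2 distinct elements has 2^(n-1) - 1 ways to split into two non-empty, unordered parts, so that is how many results `unmerges` should produce. The block repeats `unmerges` so it runs on its own; `expectedSplits` is a hypothetical helper.

```haskell
-- unmerges as above: all ways to split a list into two non-empty
-- parts, ignoring the order of the two parts.
unmerges :: [a] -> [([a], [a])]
unmerges [x, y] = [([x], [y])]
unmerges (x:xs) =
  ([x], xs) :
  concat [ [(x:ys, zs), (ys, x:zs)] | (ys, zs) <- unmerges xs ]
unmerges _ = []

-- Expected number of unordered 2-partitions of an n-element list.
-- Only meaningful for n >= 2: on a singleton, unmerges returns
-- [([x], [])], but allExprs never calls it there.
expectedSplits :: Int -> Int
expectedSplits n = 2 ^ (n - 1) - 1
```

The count follows the shape of the code: the head element either sits alone (one split) or joins either side of each split of the tail (two per split), so u(n) = 1 + 2·u(n-1) with u(2) = 1.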
Next, we need to make the recursive calls:
allExprs :: (a -> a -> [a]) -> [a] -> [a]
allExprs _ [x] = [x]
allExprs c xs =
  [ e
  | (ys,zs) <- unmerges xs
  , y <- allExprs c ys
  , z <- allExprs c zs
  , e <- c y z ]
Finally, using the simple-reflect library, we can take a look at the output:
>>> allExprs (\x y -> [x+y,x*y]) [1,2] :: [Expr]
[1 + 2,1 * 2]
>>> allExprs (\x y -> [x+y]) [1,2,3] :: [Expr]
[1 + (2 + 3),1 + 2 + 3,2 + (1 + 3)]
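If you don’t have simple-reflect installed, a combining function that builds `String`s shows the same shapes. This is a self-contained sketch; `render` is my own hypothetical helper that prints additions with explicit parentheses.

```haskell
unmerges :: [a] -> [([a], [a])]
unmerges [x, y] = [([x], [y])]
unmerges (x:xs) =
  ([x], xs) :
  concat [ [(x:ys, zs), (ys, x:zs)] | (ys, zs) <- unmerges xs ]
unmerges _ = []

-- Every expression over the inputs, combining subresults with c.
allExprs :: (a -> a -> [a]) -> [a] -> [a]
allExprs _ [x] = [x]
allExprs c xs =
  [ e
  | (ys, zs) <- unmerges xs
  , y <- allExprs c ys
  , z <- allExprs c zs
  , e <- c y z ]

-- Render a single addition with explicit parentheses.
render :: String -> String -> [String]
render x y = ["(" ++ x ++ "+" ++ y ++ ")"]
```

Here `allExprs render ["1","2","3"]` gives `["(1+(2+3))","((1+2)+3)","(2+(1+3))"]`, mirroring the output above. With a single operator, 2, 3, 4, and 5 inputs give 1, 3, 15, and 105 expressions respectively, which hints at how quickly the search space grows.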
Even at this early stage, we can actually already write a rudimentary solution:
countdown :: [Integer] -> Integer -> [Expr]
countdown xs targ =
  filter
    ((==) targ . toInteger)
    (allExprs
      (\x y -> [x,y,x+y,x*y])
      (map fromInteger xs))
>>> mapM_ print (countdown [100,25,1,5,3,10] 586)
1 + (100 * 5 + (25 * 3 + 10))
1 + (100 * 5 + 25 * 3 + 10)
1 + (25 * 3 + (100 * 5 + 10))
1 + 100 * 5 + (25 * 3 + 10)
100 * 5 + (1 + (25 * 3 + 10))
100 * 5 + (1 + 25 * 3 + 10)
100 * 5 + (25 * 3 + (1 + 10))
1 + (100 * 5 + 25 * 3) + 10
1 + 100 * 5 + 25 * 3 + 10
100 * 5 + (1 + 25 * 3) + 10
100 * 5 + 25 * 3 + (1 + 10)
1 + 25 * 3 + (100 * 5 + 10)
25 * 3 + (1 + (100 * 5 + 10))
25 * 3 + (1 + 100 * 5 + 10)
25 * 3 + (100 * 5 + (1 + 10))
As you can see from the output, there’s a lot of repetition. We’ll need to do some memoization to speed it up.
Pure Memoization
The normal way most programmers think about “memoization” is something like this:
memo_dict = {0:0,1:1}

def fib(n):
    if n in memo_dict:
        return memo_dict[n]
    else:
        res = fib(n-1) + fib(n-2)
        memo_dict[n] = res
        return res
In other words, it’s a fundamentally stateful process. We need to mutate some mapping when we haven’t seen the argument before.
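For comparison, here is the same stateful process in Haskell, with the cache threaded through by hand. It’s a minimal sketch using `Data.Map.Strict` from the containers package; `fibMemo` is my own name.

```haskell
import qualified Data.Map.Strict as Map

-- Look n up in the cache; on a miss, compute it from the two previous
-- values, threading the (possibly updated) cache through each call.
fibMemo :: Int -> Map.Map Int Integer -> (Integer, Map.Map Int Integer)
fibMemo n memo =
  case Map.lookup n memo of
    Just r -> (r, memo)
    Nothing ->
      let (a, memo1) = fibMemo (n - 1) memo
          (b, memo2) = fibMemo (n - 2) memo1
          res = a + b
      in (res, Map.insert n res memo2)

-- Start from the same base cases as memo_dict.
fib :: Int -> Integer
fib n = fst (fibMemo n (Map.fromList [(0, 0), (1, 1)]))
```

Every call both consumes and returns the map: that plumbing is exactly the state that laziness lets us avoid.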
Using laziness, though, we can emulate the same behavior purely. Instead of mutating the mapping on function calls, we fill the whole thing at the beginning, and then index into it. As long as the mapping is lazy, it’ll only evaluate the function calls when they’re needed. We could use a list as our mapping from the natural numbers:
fibs = 0 : 1 : map fib [2..]
fib n = fibs !! (n-1) + fibs !! (n-2)
The benefit here is that we avoid the extra work of redundant calls. However, we pay for the speedup in three ways:
- Space: we need to take up memory space storing the cached solutions.
- Indexing: while we no longer have to pay for the expensive recursive calls, we do now have to pay for indexing into the data structure. In this example, we’re paying linear time to index into the list.
- Generality: the memoization is tied directly to the argument type of the function. We need to be able to use the argument to our memoized function as an index into some data structure. While a lot of argument types admit some type of indexing (whether they’re Hashable, Ord, etc.), some don’t, and we can’t memoize those using this technique.
We’re going to look at a technique that allows us to somewhat mitigate the second and third costs above (indexing and generality), using something called a nexus.
Nexuses
The standard technique of memoization is focused on the arguments to the function, creating a concrete representation of them in memory to map to the results. Using nexuses, as described in Bird and Hinze (2003), we’ll instead focus on the function itself, creating a concrete representation of its call graph in memory. Here’s the call graph of Fibonacci:
fib(6)=8
├─fib(5)=5
│ ├─fib(4)=3
│ │ ├─fib(3)=2
│ │ │ ├─fib(2)=1
│ │ │ │ ├─fib(1)=1
│ │ │ │ └─fib(0)=0
│ │ │ └─fib(1)=1
│ │ └─fib(2)=1
│ │   ├─fib(1)=1
│ │   └─fib(0)=0
│ └─fib(3)=2
│   ├─fib(2)=1
│   │ ├─fib(1)=1
│   │ └─fib(0)=0
│   └─fib(1)=1
└─fib(4)=3
  ├─fib(3)=2
  │ ├─fib(2)=1
  │ │ ├─fib(1)=1
  │ │ └─fib(0)=0
  │ └─fib(1)=1
  └─fib(2)=1
    ├─fib(1)=1
    └─fib(0)=0
Turning that into a concrete datatype wouldn’t do us much good: it still has the massively redundant computations in it. However, we can recognize that entire subtrees are duplicates of each other: in those cases, instead of creating both subtrees, we could just create one and have each parent point to it¹:
        ┌fib(5)=5┬────────┬fib(3)=2┬────────┬fib(1)=1
fib(6)=8┤        │        │        │        │
        └fib(4)=3┴────────┴fib(2)=1┴────────┴fib(0)=0
This is a nexus. In Haskell, it’s not observably different from the other form, except that it takes up significantly less space. It’s also much quicker to construct.
If we use it to memoize fib, we’ll no longer be indexing on the argument: we’ll instead follow the relevant branch in the tree to the subcomputation, which is just chasing a pointer. It also means the argument doesn’t have to be constrained to any specific type. Here’s how you’d do it:
data Tree
  = Leaf
  | Node
  { val   :: Integer
  , left  :: Tree
  , right :: Tree
  }

fib :: Integer -> Integer
fib = val . go
  where
    go 0 = Node 0 Leaf Leaf
    go 1 = Node 1 (Node 0 Leaf Leaf) Leaf
    go n = node t (left t) where t = go (n-1)
    node l r = Node (val l + val r) l r
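Because the sharing is invisible to pure code, one way to convince yourself it happens is to count node constructions with an `unsafePerformIO` counter, a testing hack in the same spirit as the unsafe tricks used in the testing section. This sketch is my own: `counter`, `tick`, `fibNexus`, `fibNaive`, and `callsFor` are hypothetical names, and the exact counts assume GHC neither duplicates nor re-runs the ticked thunks. That is how I’d expect an ordinary build to behave, but the language doesn’t promise it.

```haskell
import Control.Exception (evaluate)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef, writeIORef)
import System.IO.Unsafe (unsafePerformIO)

-- A global call counter (a testing hack, not something to ship).
counter :: IORef Int
counter = unsafePerformIO (newIORef 0)
{-# NOINLINE counter #-}

-- Return x unchanged, bumping the counter when the result is first demanded.
tick :: a -> a
tick x = unsafePerformIO (modifyIORef' counter (+ 1) >> pure x)
{-# NOINLINE tick #-}

data Tree = Leaf | Node {val :: Integer, left :: Tree, right :: Tree}

-- Nexus version: each node is built once and then shared by pointer.
fibNexus :: Integer -> Integer
fibNexus = val . go
  where
    go 0 = tick (Node 0 Leaf Leaf)
    go 1 = tick (Node 1 (go 0) Leaf)
    go n = tick (Node (val t + val (left t)) t (left t))
      where
        t = go (n - 1)

-- Naive version: every recursive call is a fresh computation.
fibNaive :: Integer -> Integer
fibNaive n = tick (if n < 2 then n else fibNaive (n - 1) + fibNaive (n - 2))

-- Reset the counter, fully evaluate f n, and report (result, ticks).
callsFor :: (Integer -> Integer) -> Integer -> IO (Integer, Int)
callsFor f n = do
  writeIORef counter 0
  r <- evaluate (f n)
  c <- readIORef counter
  pure (r, c)
```

I’d expect `callsFor fibNexus 10` to report `(55, 11)`, one construction per node from fib(0) to fib(10), and `callsFor fibNaive 10` to report `(55, 177)`, the 2·F(n+1) - 1 size of the unshared call tree.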
So this approach sounds amazing, right? No constraints on the argument type, no need to pay for indexing: why doesn’t everyone use it everywhere? The main reason is that figuring out a nexus for the call-graph is hard. In fact, finding an optimal one is NP-hard in general (Steffen and Giegerich 2006).
The second problem is that it’s difficult to abstract out. The standard technique of memoization relies on building a mapping from keys to values: about as bread-and-butter as it gets in programming. What’s more, we already know how to say “values of this type can be used efficiently as keys in some mapping”: for Data.Map it’s Ord, for Data.HashMap it’s Hashable. All of this together means we can build a nice library for memoization which exports the following two functions:
memoHash :: Hashable a => (a -> b) -> (a -> b)
memoOrd :: Ord a => (a -> b) -> (a -> b)
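The keys-to-values approach really is easy to package up. As a small, dependency-free sketch of the idea (not an implementation of the `memoHash` or `memoOrd` signatures above), here is a version specialised to non-negative `Int` keys that reuses the lazy-list table from earlier. `memoNat`, `fibOpen`, and `fibMemo` are my own names.

```haskell
-- Tabulate f over the naturals in a lazy list: each entry is computed
-- on first use and shared afterwards. Lookup is linear in the key.
memoNat :: (Int -> b) -> (Int -> b)
memoNat f = (table !!)
  where
    table = map f [0 ..]

-- Fibonacci with open recursion: recursive calls go through `self`,
-- so they can be routed back through the memo table.
fibOpen :: (Int -> Integer) -> Int -> Integer
fibOpen _ 0 = 0
fibOpen _ 1 = 1
fibOpen self n = self (n - 1) + self (n - 2)

-- A top-level constant, so the table is built once and shared.
fibMemo :: Int -> Integer
fibMemo = memoNat (fibOpen fibMemo)
```

`fibMemo 30` returns 832040 promptly. The catch, as discussed above, is that this only works because `Int` keys can be laid out in a list, and each lookup pays for indexing.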
Building a nexus, however, is not bread-and-butter. On top of that, it’s difficult to say something like “recursive functions of this structure can be constructed using a nexus”. What’s the typeclass for that? In comparison to the signatures above, the constraint will need to be on the arrows, not the a. Even talking about the structure of recursive functions is usually regarded as a somewhat advanced subject: that said, the recursion-schemes package allows us to do so, and even has facilities for constructing something like nexuses with histomorphisms (Tobin 2016). I’m still looking for a library that manages to abstract nexuses in an ergonomic way, so I’d love to hear about one if it exists (or about some more general form which accomplishes the same).
Memoizing Countdown
That’s enough preamble. The nexus we want to construct for countdown is not going to memoize as much as possible: in particular, we’re only going to memoize the shape of the trees, not the operators used. This will massively reduce the memory overhead, and still give a decent speedup (Bird and Mu 2005, 11 “building a skeleton tree first”).
With that in mind, the ideal nexus looks something like this:
We can represent the tree in Haskell as a rose tree:
data Tree a
  = Node
  { root   :: a
  , forest :: Forest a
  }

type Forest a = [Tree a]
Constructing the nexus itself isn’t actually the most interesting part of this solution: consuming it is. We need to be able to go from the structure above into a list that’s the equivalent of unmerges. Doing a breadth-first traversal of the diagram above (without the top element) will give us:
If you split that list in half, and zip the first half with the reverse of the second, you’ll get the output of unmerges.
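Why does zipping with the reverse work? As I read it, the breadth-first list (minus the top element) holds the proper, non-empty subsets grouped by size. Under that assumption, here is a sketch with my own hypothetical helpers `choose`, `properSubsets`, and `splits`; it lists the groups smallest first, which doesn’t affect the pairing.

```haskell
-- All k-element sub-lists, in lexicographic order of position.
choose :: Int -> [a] -> [[a]]
choose 0 _ = [[]]
choose _ [] = []
choose k (x:xs) = map (x:) (choose (k - 1) xs) ++ choose k xs

-- Proper, non-empty sub-lists grouped by size: a stand-in for the
-- breadth-first traversal of the nexus without its top element.
properSubsets :: [a] -> [[a]]
properSubsets xs = concat [ choose k xs | k <- [1 .. length xs - 1] ]

-- Zip the list with its own reverse and keep the first half, so each
-- subset meets its complement exactly once.
splits :: [a] -> [([a], [a])]
splits xs = take (length ys `div` 2) (zip ys (reverse ys))
  where
    ys = properSubsets xs
```

For `"abc"` this yields `[("a","bc"),("b","ac"),("c","ab")]`, the same three splits as `unmerges "abc"` in a different order. The pairing works because taking complements reverses the lexicographic order of fixed-size subsets.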
However, the breadth-first traversal of the diagram isn’t the same thing as the breadth-first traversal of the rose tree. The latter will visit the nodes of the first row, then the children of the first node, then the children of the second node, and so on: and here’s our problem. A node shared between two parents gets traversed twice, because we can’t know that both parents are pointing to the same value. What we have to do is first prune the tree, removing duplicates, and then perform a breadth-first traversal on that.
Pruning
Luckily, the duplicates follow a pattern, allowing us to remove them without having to do any equality checking. In each row, the first node has no duplicates in its children, the second’s first child is a duplicate, the third’s first and second children are duplicates, and so on. You should be able to see this in the diagram above. Adapting a little from the paper, we get an algorithm like this:
para :: (a -> [a] -> b -> b) -> b -> [a] -> b
para f b = go
  where
    go [] = b
    go (x:xs) = f x xs (go xs)

prune :: Forest a -> Forest a
prune ts = pruneAt ts 0
  where
    pruneAt = para f (const [])
    f (Node x []) t _ _ = Node x [] : t
    f (Node x us) _ a k =
      Node x (pruneAt (drop k us) k) : a (k + 1)
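The dropping pattern can be seen on a single row in isolation. In this sketch (hypothetical names; I assume each node’s children are listed in lexicographic order), `row` holds the children of the size-3 subsets of "abcd", and dropping k children from the k-th node leaves each size-2 subset exactly once. The real `prune` also carries k down into the next level, which this one-row version leaves out.

```haskell
-- Children of each size-3 subset of "abcd": its size-2 subsets.
row :: [[String]]
row =
  [ ["ab", "ac", "bc"] -- children of "abc"
  , ["ab", "ad", "bd"] -- children of "abd"
  , ["ac", "ad", "cd"] -- children of "acd"
  , ["bc", "bd", "cd"] -- children of "bcd"
  ]

-- The k-th node (counting from 0) shares its first k children with
-- earlier nodes in the row, so drop them.
pruneRow :: [[a]] -> [[a]]
pruneRow = zipWith drop [0 ..]
```

`pruneRow row` leaves `[["ab","ac","bc"],["ad","bd"],["cd"],[]]`, so every size-2 subset appears exactly once.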
Breadth-First Traversal
I went through this in a previous post, so here’s just the final solution:
breadthFirst :: Forest a -> [a]
breadthFirst ts = foldr f b ts []
  where
    f (Node x xs) fw bw = x : fw (xs:bw)

    b [] = []
    b q = foldl (foldr f) b q []
With the appropriate incantations, this is actually the fastest implementation I’ve found.
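To experiment with the traversal without a custom tree type, here is the same function written against `Data.Tree` from containers, whose `Node` has the same shape (a root label plus a forest of children). `bfsReference` and `example` are my own additions, used to compare against containers’ `levels`.

```haskell
import Data.Tree (Forest, Tree (..), levels)

-- fw is the continuation for the rest of the current level; bw is a
-- stack of child forests waiting for the next level.
breadthFirst :: Forest a -> [a]
breadthFirst ts = foldr f b ts []
  where
    f (Node x xs) fw bw = x : fw (xs : bw)

    -- End of a level: stop if nothing is queued, otherwise traverse the
    -- queued forests (foldl undoes the reversal of the stack).
    b [] = []
    b q = foldl (foldr f) b q []

-- Reference answer for a single tree, via containers' levels.
bfsReference :: Tree a -> [a]
bfsReference = concat . levels

example :: Tree Int
example = Node 1 [Node 2 [Node 4 [], Node 5 []], Node 3 [Node 6 []]]
```

`breadthFirst [example]` gives `[1,2,3,4,5,6]`, the same as `concat (levels example)`.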
Fusing
We can actually inline both of the above functions, fusing them together:
spanNexus :: Forest a -> [a]
spanNexus ts = foldr f (const b) ts 0 []
  where
    f (Node x us) fw k bw = x : fw (k+1) ((drop k us, k) : bw)

    b [] = []
    b qs = foldl (uncurry . foldr f . const) b qs []
Halving, Convolving, and Folding
So, now we can go from the tree to our list of splits. The next step is to convert that list into the output of unmerges, by zipping the reverse of the first half with the second. We can use an algorithm described in Danvy and Goldberg (2005) to do the zipping and reversing:
fold xs n = go xs n (const [])
  where
    go xs 0 k = k xs
    go (x:xs) n k = go xs (n-2) (\(y:ys) -> (x,y) : k ys)
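To check the behaviour, here is the same function as a standalone sketch, renamed `convolve` (my own name) to avoid confusion with `Data.Foldable.fold`, and assuming `n` is the list’s length and is even.

```haskell
-- Pair the reverse of the first half of a list with its second half in
-- one pass, without building a reversed list. The counter drops by 2
-- per element, so the recursion stops halfway; the continuation then
-- consumes the second half front to back while the stacked-up closures
-- release the first half back to front.
convolve :: [a] -> Int -> [(a, a)]
convolve xs0 n0 = go xs0 n0 (const [])
  where
    go xs 0 k = k xs
    go (x:xs) n k = go xs (n - 2) (\(y:ys) -> (x, y) : k ys)
```

`convolve "abcd" 4` returns `[('b','c'),('a','d')]`: the reverse of "ab" zipped with "cd".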
And we can inline the function which collapses those results into one:
fold xs n = go n xs (const [])
  where
    -- cmb, the combining function, comes from the enclosing scope
    go 0 xss k = k xss
    go n (xs:xss) k =
      go (n-2) xss (\(ys:yss) -> [ z
                                | x <- xs
                                , y <- ys
                                , z <- cmb x y
                                ] ++ k yss)
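The inlined version can be tried out the same way. In this sketch `cmb` is passed as an explicit argument (in the post it comes from the enclosing scope), `combineSplits` is my own name, and the input is assumed to list each subset’s results so that positions i and (length - 1 - i) hold complementary subsets.

```haskell
-- Combine each subset's results with its complement's results.
-- n is the length of xss (assumed even).
combineSplits :: (a -> a -> [a]) -> [[a]] -> Int -> [a]
combineSplits cmb xss0 n0 = go n0 xss0 (const [])
  where
    go 0 xss k = k xss
    go n (xs:xss) k =
      go (n - 2) xss (\(ys:yss) -> [ z | x <- xs, y <- ys, z <- cmb x y ] ++ k yss)
```

With the subsets of "abc" arranged as `[a, b, c, ab, ac, bc]` and a string-joining `cmb`, it returns `["c+ab","b+ac","a+bc"]`: each subset meets its complement exactly once.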
And that’s all we need!
Full Code
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}

import qualified Data.Tree as Rose

data Tree a
  = Leaf Int a
  | Node [Tree a]
  deriving (Show,Eq,Functor)

enumerateTrees :: (a -> a -> [a]) -> [a] -> [a]
enumerateTrees _ [] = []
enumerateTrees cmb xs = (extract . steps . initial) xs
  where
    step = map nodes . group

    steps [x] = x
    steps xs = steps (step xs)

    initial = map (Leaf 1 . flip Rose.Node [] . pure)

    extract (Leaf _ x) = Rose.rootLabel x
    extract (Node [x]) = extract x

    group [_] = []
    group (Leaf _ x:vs) = Node [Leaf 2 [x, y] | Leaf _ y <- vs] : group vs
    group (Node u:vs) = Node (zipWith comb (group u) vs) : group vs

    comb (Leaf n xs) (Leaf _ x) = Leaf (n + 1) (xs ++ [x])
    comb (Node us) (Node vs) = Node (zipWith comb us vs)

    forest ts = foldr f (const b) ts 0 []
      where
        f (Rose.Node x []) fw !k bw = x : fw (k + 1) bw
        f (Rose.Node x us) fw !k bw = x : fw (k + 1) ((drop k us, k) : bw)

        b [] = []
        b qs = foldl (uncurry . foldr f . const) b qs []

    nodes (Leaf n x) = Leaf 1 (node n x)
    nodes (Node xs) = Node (map nodes xs)

    node n ts = Rose.Node (walk (2 ^ n - 2) (forest ts) (const [])) ts
      where
        walk 0 xss k = k xss
        walk n (xs:xss) k =
          walk (n-2) xss (\(ys:yss) -> [ z
                                       | x <- xs
                                       , y <- ys
                                       , z <- cmb x y
                                       ] ++ k yss)
Using it for Countdown
The first thing to do for the Countdown solution is to figure out a representation for expressions. The one from simple-reflect is perfect for displaying the result, but we should memoize its calculation.
import Debug.SimpleReflect (Expr)

data Memoed
  = Memoed
  { expr   :: Expr
  , result :: Int
  }
Then, some helpers for building:
import Data.Function (on)

data Op = Add | Dif | Mul | Div

binOp :: (Expr -> Expr -> Expr) -> (Int -> Int -> Int) -> Memoed -> Memoed -> Memoed
binOp f g x y = Memoed ((f `on` expr) x y) ((g `on` result) x y)

apply :: Op -> Memoed -> Memoed -> Memoed
apply Add x y = binOp (+) (+) x y
apply Dif x y
  | result y < result x = binOp (-) (-) x y
  | otherwise           = binOp (-) (-) y x
apply Mul x y = binOp (*) (*) x y
apply Div x y = binOp div div x y
Finally, the full algorithm:
import qualified Data.IntSet as IntSet

enumerateExprs :: [Int] -> [Memoed]
enumerateExprs = enumerateTrees cmb . map (\x -> Memoed (fromIntegral x) x)
  where
    cmb x y =
      nubs $
        x :
        y :
        [ apply op x y
        | op <- [Add, Dif, Mul, Div]
        , legal op (result x) (result y) ]

    legal Add _ _ = True
    legal Dif x y = x /= y
    legal Mul _ _ = True
    legal Div x y = x `mod` y == 0

    nubs xs = foldr f (const []) xs IntSet.empty
      where
        f e a s
          | IntSet.member (result e) s = a s
          | otherwise = e : a (IntSet.insert (result e) s)

countdown :: Int -> [Int] -> [Expr]
countdown targ = map expr . filter ((==) targ . result) . enumerateExprs
>>> (mapM_ print . reduction . head) (countdown 586 [100,25,1,5,3,10])
25 * 3 + 1 + (100 * 5 + 10)
75 + 1 + (100 * 5 + 10)
76 + (100 * 5 + 10)
76 + (500 + 10)
76 + 510
586
There are some optimizations going on here, taken mainly from Bird and Mu (2005):
- We filter out illegal operations, as described originally.
- We filter out any expressions that have the same value.
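The second optimization is the `nubs` helper inside `enumerateExprs`. Pulled out and generalised to an arbitrary `Int`-valued key (`nubOn` is my own name for this sketch), it keeps the first expression for each value and drops the rest. Because the set of seen keys is threaded through a right fold as a continuation argument, the output is still produced lazily.

```haskell
import qualified Data.IntSet as IntSet

-- Keep the first element for each key, skipping later elements whose
-- key has already been seen.
nubOn :: (a -> Int) -> [a] -> [a]
nubOn key xs = foldr f (const []) xs IntSet.empty
  where
    f e a seen
      | IntSet.member (key e) seen = a seen
      | otherwise = e : a (IntSet.insert (key e) seen)
```

`nubOn id [3,1,3,2,1]` gives `[3,1,2]`, and because it’s lazy, `take 3 (nubOn id (cycle [1,2,3]))` still terminates.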
Testing the Implementation
So we’ve followed the paper, written the code: time to test. The specification of the function is relatively simple: calculate all applications of the commutative operator to some input, without recalculating subtrees.
We’ll need a free structure for the “commutative operator”:
data Tree a
  = Leaf a
  | Tree a :^: Tree a
  deriving (Foldable,Eq,Ord,Show)
Here’s the problem: it’s not commutative! We can remedy it by only exporting a constructor that creates the tree in a commutative way, and we can make it a pattern synonym so it looks normal:
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE PatternSynonyms #-}

module Commutative
  ( Tree(Leaf)
  , pattern (:*:)
  ) where

data Tree a
  = Leaf a
  | Tree a :^: Tree a
  deriving (Eq,Ord,Show,Foldable)

pattern (:*:) :: Ord a => Tree a -> Tree a -> Tree a
pattern xs :*: ys <- xs :^: ys where
  xs :*: ys
    | xs <= ys  = xs :^: ys
    | otherwise = ys :^: xs

{-# COMPLETE Leaf, (:*:) #-}
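The smart builder can be exercised directly. This standalone sketch puts everything in one module, so the raw `:^:` constructor is still visible here; the guarantee only holds for code outside `Commutative`. The `size` function is my own example of matching on the exported patterns.

```haskell
{-# LANGUAGE PatternSynonyms #-}

data Tree a
  = Leaf a
  | Tree a :^: Tree a
  deriving (Eq, Ord, Show)

-- Matching is plain :^:; building puts the smaller subtree on the
-- left, so both argument orders produce the same value.
pattern (:*:) :: Ord a => Tree a -> Tree a -> Tree a
pattern xs :*: ys <- xs :^: ys
  where
    xs :*: ys
      | xs <= ys = xs :^: ys
      | otherwise = ys :^: xs

{-# COMPLETE Leaf, (:*:) #-}

-- Count the leaves, matching only on Leaf and (:*:).
size :: Ord a => Tree a -> Int
size (Leaf _) = 1
size (l :*: r) = size l + size r
```

Both `Leaf 1 :*: Leaf 2` and `Leaf 2 :*: Leaf 1` build `Leaf 1 :^: Leaf 2`, which is what lets the tests compare trees with ordinary equality.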
Now we need to check if all applications are actually tested. First, to generate all trees:
import Control.Applicative (liftA2)
import Control.Arrow ((***))
import Data.Function (on)
import Data.Set (Set)
import qualified Data.Set as Set

allTrees :: Ord a => [a] -> Set (Tree a)
allTrees [x] = Set.singleton (Leaf x)
allTrees xs = Set.unions (map (uncurry f) (unmerges xs))
  where
    f ls rs = Set.fromList ((liftA2 (:*:) `on` (Set.toList . allTrees)) ls rs)

allSubTrees :: Ord a => [a] -> Set (Tree a)
allSubTrees [x] = Set.singleton (Leaf x)
allSubTrees xs =
  Set.unions (map (uncurry f . (allSubTrees *** allSubTrees)) (unmerges xs))
  where
    f ls rs =
      Set.unions
        [ls, rs, Set.fromList ((liftA2 (:*:) `on` Set.toList) ls rs)]
Then, to test:
import Numeric.Natural (Natural)

prop_exhaustiveSearch :: Natural -> Bool
prop_exhaustiveSearch n =
  let src = [0 .. fromIntegral n]
      expect = allSubTrees src
      actual =
        Set.fromList
          (enumerateTrees
             (\xs ys -> [xs, ys, xs :*: ys])
             (map Leaf src))
  in expect == actual

prop_exhaustiveSearchFull :: Natural -> Bool
prop_exhaustiveSearchFull n =
  let src = [0 .. fromIntegral n]
      expect = Map.fromSet (const 1) (allTrees src)
      actual =
        freqs
          (enumerateTrees
             (\xs ys -> [xs :*: ys])
             (map Leaf src))
  in expect == actual
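The properties rely on a few helpers the post doesn’t show: `freqs`, `incr`, and `mapCompare`. Going by how they’re used, `freqs` counts occurrences and `incr` bumps a single count. The definitions below are my guesses, not the originals, and I haven’t attempted `mapCompare`, which appears only to build a counterexample message.

```haskell
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

-- Bump the count for one key (a guess based on its use in
-- traceSubsequences).
incr :: Ord a => a -> Map a Int -> Map a Int
incr k = Map.insertWith (+) k 1

-- Count how often each element occurs in any Foldable container
-- (a guess based on its comparison with Map.fromSet (const 1)).
freqs :: (Foldable f, Ord a) => f a -> Map a Int
freqs = foldr incr Map.empty
```

With these, `freqs "abca"` maps 'a' to 2 and 'b' and 'c' to 1, and `Map.fromSet (const 1) s == freqs xs` holds exactly when `xs` contains each element of `s` once and nothing else.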
Testing for repeated calls is more tricky. Remember, the memoization is supposed to be unobservable: in order to see it, we’re going to have to use some unsafe operations.
import Control.Monad.ST (runST)
import Control.Monad.ST.Unsafe (unsafeSTToIO)
import Data.Foldable (traverse_)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.STRef (modifySTRef', newSTRef, readSTRef)
import System.IO.Unsafe (unsafePerformIO)

traceSubsequences
  :: ((Tree Int -> Tree Int -> [Tree Int]) -> [Tree Int] -> [Tree Int])
  -> [Int]
  -> (Map (Tree Int) Int, [Tree Int])
traceSubsequences enm ints =
  runST $ do
    ref <- newSTRef Map.empty
    let res = enm (combine ref) (map (conv ref) ints)
    traverse_ (foldr seq (pure ())) res
    intm <- readSTRef ref
    pure (intm, res)
  where
    combine ref xs ys = unsafeRunST ([xs :*: ys] <$ modifySTRef' ref (incr (xs :*: ys)))
    {-# NOINLINE combine #-}
    conv ref x = unsafeRunST (Leaf x <$ modifySTRef' ref (incr (Leaf x)))
    {-# NOINLINE conv #-}
    unsafeRunST cmp = unsafePerformIO (unsafeSTToIO cmp)
import Test.QuickCheck

prop_noRepeatedCalls :: Property
prop_noRepeatedCalls =
  property $ sized $ \n ->
    pure $
      let src = [0 .. n]
          (tint,tres) = fmap freqs (traceSubsequences enumerateTrees src)
          (fint,fres) = fmap freqs (traceSubsequences dummyEnumerate src)
      in counterexample
           (mapCompare (freqs (allSubTrees src)) tint)
           (all (1 ==) tint) .&&.
         counterexample (mapCompare tres fres) (tres == fres) .&&.
         (n > 2 ==> tint /= fint)
Here, dummyEnumerate is some method which performs the same task, but doesn’t construct a nexus, so we can ensure that our tests really do catch faulty implementations.
¹ If you think that structure looks more like a funny linked list than a tree, that’s because it is. Instead of talking about “left” and “right” branches, we could talk about the first and second elements in a list: in fact, this is exactly what’s happening in the famous zipWith Fibonacci implementation (in reverse):
fibs = 0 : 1 : zipWith (+) fibs (tail fibs)
Or, in my favourite version:
fib n = fix ((:) 0 . scanl (+) 1) !! n