A Different Probability Monad

Posted on September 27, 2016

One of the more unusual monads is the “probability monad”:

{-# language PatternSynonyms, ViewPatterns #-}
{-# language DeriveFunctor, DeriveFoldable #-}
{-# language BangPatterns #-}

module Prob where

import Control.Arrow
import Data.Ratio
import Data.Foldable

newtype Probability a = Probability
  { runProb :: [(a,Rational)] }
  
data Coin = Heads | Tails

toss :: Probability Coin
toss = Probability [(Heads, 1 % 2), (Tails, 1 % 2)]

Although it’s a little inefficient, it’s an elegant representation. I’ve written about it before here.

It has some notable deficiencies, though. For instance: the user has to constantly check that all the probabilities add up to one. Its list can be empty, which doesn’t make sense. Also, individual outcomes can appear more than once in the same list.
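
To make those three invariants concrete, here is a minimal sketch of the check a user would have to run by hand. `isValidDist` is a hypothetical helper, not part of the original module, and it works on the raw list rather than on the `Probability` newtype so that the block stands alone.

```haskell
import Data.List (nub)
import Data.Ratio ((%))

-- Hypothetical helper: checks the three invariants that the list
-- representation does not enforce by construction.
isValidDist :: Eq a => [(a, Rational)] -> Bool
isValidDist xs =
     not (null xs)                              -- at least one outcome
  && sum (map snd xs) == 1                      -- probabilities add up to one
  && length (nub outcomes) == length outcomes   -- no outcome listed twice
  where
    outcomes = map fst xs

-- Illustrative values only.
fairCoin, repeated :: [(Bool, Rational)]
fairCoin = [(True, 1 % 2), (False, 1 % 2)]
repeated = [(True, 1 % 2), (True, 1 % 2)]
```

Under these definitions `isValidDist fairCoin` holds, while `isValidDist repeated` and `isValidDist` of an empty list do not.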

A first go at fixing the problem might look something like this:

newtype Distrib a = Distrib
  { runDist :: [(a,Rational)] }

tossProb :: Distrib Coin
tossProb = Distrib [(Heads, 1), (Tails, 1)]

The type is the same as before: it’s the semantics which have changed. The second fields of the tuples no longer have to add up to one. The list can still be empty, though, and now finding the probability of, say, the head, looks like this:

probHead :: Distrib a -> Rational
probHead (Distrib xs@((_,p):_)) = p / sum [ q | (_,q) <- xs ]

Infinite lists aren’t possible, either: the sum has to traverse the whole list.
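
A total version of the lookup, as a rough sketch, shows how much bookkeeping remains with this representation. `probOfDistrib` is a hypothetical helper, not from the original module; it assumes that an empty list or a zero total weight should yield `Nothing`, and that repeated outcomes should have their weights added together.

```haskell
newtype Distrib a = Distrib
  { runDist :: [(a, Rational)] }

-- Hypothetical helper: probability of one outcome under the
-- weighted-list representation.
probOfDistrib :: Eq a => a -> Distrib a -> Maybe Rational
probOfDistrib _ (Distrib []) = Nothing          -- empty list: no distribution
probOfDistrib x (Distrib xs)
  | total == 0 = Nothing                        -- nothing to normalise by
  | otherwise  = Just (sum [ w | (y, w) <- xs, y == x ] / total)
  where
    total = sum (map snd xs)                    -- needs the entire list
```

For instance, with weights `[('a', 1), ('b', 3)]` the outcome `'a'` comes out as one in four.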

One other way to look at the problem is to mimic the structure of cons-lists. Something like this:

data Odds a = Certainly a
            | Odds a Rational (Odds a)
            deriving (Eq, Functor, Foldable, Show)

Here, the Odds constructor (analogous to (:)) contains the betting-style odds of the head element vs. the rest of the list. The coin from before is represented by:

tossOdds :: Odds Coin
tossOdds = Odds Heads (1 % 1) (Certainly Tails)
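
As a slightly larger illustration, here is how a fair three-sided die could be written, together with a sketch of a conversion back to plain probabilities. The names `die3` and `toProbs` are assumptions made for this example and are not part of the original code.

```haskell
import Data.Ratio ((%))

data Odds a = Certainly a
            | Odds a Rational (Odds a)

-- A fair three-sided die: face 1 against the other two is 1:2,
-- then face 2 against face 3 is 1:1.
die3 :: Odds Int
die3 = Odds 1 (1 % 2) (Odds 2 (1 % 1) (Certainly 3))

-- Hypothetical helper: expand into (outcome, probability) pairs.
-- `rest` is the probability mass that has not been assigned yet.
toProbs :: Odds a -> [(a, Rational)]
toProbs = go 1 where
  go rest (Certainly x) = [(x, rest)]
  go rest (Odds x o xs) =
    (x, rest * o / (o + 1)) : go (rest / (o + 1)) xs
```

Each face of `die3` ends up with probability `1 % 3` under `toProbs`. Note that the odds stored at each step are relative to what remains, not to the whole distribution.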

This representation has tons of nice properties. First, let’s use some pattern-synonym magic for rationals:

pattern (:%) :: Integer -> Integer -> Rational
pattern n :% d <- (numerator &&& denominator -> (n,d)) where
  n :% d = n % d

Then, finding the probability of the head element is this:

probHeadOdds :: Odds a -> Rational
probHeadOdds (Certainly _) = 1
probHeadOdds (Odds _ (n :% d) _) = n :% (n + d)

The representation can handle infinite lists no problem:

probHeadOdds (Odds 'a' (1 :% 1) undefined)
1 % 2
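
The same point can be made with a genuinely infinite value rather than `undefined`. The sketch below is self-contained, so it restates the head computation with plain arithmetic instead of the pattern synonym; `headProb`, `tailOdds` and `naturals` are hypothetical names introduced for the example.

```haskell
import Data.Ratio ((%))

data Odds a = Certainly a
            | Odds a Rational (Odds a)

-- Same computation as probHeadOdds, written as o / (o + 1).
headProb :: Odds a -> Rational
headProb (Certainly _) = 1
headProb (Odds _ o _)  = o / (o + 1)

-- Hypothetical helper: drop the head, if anything follows it.
tailOdds :: Odds a -> Maybe (Odds a)
tailOdds (Certainly _) = Nothing
tailOdds (Odds _ _ xs) = Just xs

-- An infinite value: n is at odds 1:(n+1) against everything after it.
naturals :: Odds Integer
naturals = go 0 where
  go n = Odds n (1 % (n + 1)) (go (n + 1))
```

Here `headProb naturals` is `1 % 2`, and the head of the tail has probability `1 % 3` within that tail. Neither query forces more than one constructor.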

Taking the tail preserves semantics, also. To do some more involved manipulation, a fold helper is handy:

foldOdds :: (a -> Rational -> b -> b) -> (a -> b) -> Odds a -> b
foldOdds f b = r where
  r (Certainly x) = b x
  r (Odds x p xs) = f x p (r xs)

You can use this function to find the probability of a given item:

probOfEvent :: Eq a => a -> Odds a -> Rational
probOfEvent e = foldOdds f b where
  b x = if e == x then 1 else 0
  f x n r = (if e == x then n else r) / (n + 1)

This assumes that each item only occurs once. A function which combines multiple events might look like this:

probOf :: (a -> Bool) -> Odds a -> Rational
probOf p = foldOdds f b where
  b x = if p x then 1 else 0
  f x n r = (if p x then r + n else r) / (n + 1)
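
A small worked case may help here. The block below repeats `foldOdds` and `probOf` so that it stands alone, and applies them to a hand-written six-sided die; `d6`, `evenProb` and `bigProb` are illustrative names that do not appear in the original.

```haskell
import Data.Ratio ((%))

data Odds a = Certainly a
            | Odds a Rational (Odds a)

foldOdds :: (a -> Rational -> b -> b) -> (a -> b) -> Odds a -> b
foldOdds f b = r where
  r (Certainly x) = b x
  r (Odds x p xs) = f x p (r xs)

probOf :: (a -> Bool) -> Odds a -> Rational
probOf p = foldOdds f b where
  b x = if p x then 1 else 0
  f x n r = (if p x then r + n else r) / (n + 1)

-- A fair six-sided die: face k is at odds 1:(6-k) against later faces.
d6 :: Odds Int
d6 = Odds 1 (1 % 5) (Odds 2 (1 % 4) (Odds 3 (1 % 3)
       (Odds 4 (1 % 2) (Odds 5 (1 % 1) (Certainly 6)))))

evenProb, bigProb :: Rational
evenProb = probOf even d6    -- faces 2, 4 and 6
bigProb  = probOf (> 4) d6   -- faces 5 and 6
```

Working the fold from the right, `evenProb` comes to `1 % 2` and `bigProb` to `1 % 3`, as expected for a fair die.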

Some utility functions to create Odds:

equalOdds :: Foldable f => f a -> Maybe (Odds a)
equalOdds xs = case length xs of
  0 -> Nothing
  n -> Just (foldr f undefined xs (n - 1)) where
    f y a 0 = Certainly y
    f y a n = Odds y (1 % fromIntegral n) (a (n - 1))

fromDistrib :: [(a,Integer)] -> Maybe (Odds a)
fromDistrib [] = Nothing
fromDistrib xs = Just $ f (tot*lst) xs where
  (tot,lst) = foldl' (\(!t,_) e -> (t+e,e)) (0,undefined) (map snd xs)
  f _ [(x,_)] = Certainly x
  f n ((x,p):xs) = Odds x (mp % np) (f np xs) where
    mp = p * lst
    np = n - mp
                  
probOfEach :: Eq a => a -> Odds a -> Rational
probOfEach x xs = probOf (x==) xs

propOf :: Eq a => a -> [a] -> Maybe Rational
propOf _ [] = Nothing
propOf x xs = Just . uncurry (%) $
  foldl' (\(!n,!m) e -> (if x == e then n+1 else n, m+1)) (0,0) xs

This agrees with the Odds version, in the sense that the following property holds:

propOf x xs == fmap (probOfEach x) (equalOdds xs)
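
One concrete instance of that relationship, as a self-contained sketch: the block restates `foldOdds`, `probOf`, `probOfEach` and `equalOdds`, then asks for the share of `'a'` in the string `"aab"`. The name `shareOfA` is made up for the example.

```haskell
import Data.Ratio ((%))

data Odds a = Certainly a
            | Odds a Rational (Odds a)

foldOdds :: (a -> Rational -> b -> b) -> (a -> b) -> Odds a -> b
foldOdds f b = r where
  r (Certainly x) = b x
  r (Odds x p xs) = f x p (r xs)

probOf :: (a -> Bool) -> Odds a -> Rational
probOf p = foldOdds f b where
  b x = if p x then 1 else 0
  f x n r = (if p x then r + n else r) / (n + 1)

probOfEach :: Eq a => a -> Odds a -> Rational
probOfEach x xs = probOf (x ==) xs

equalOdds :: Foldable f => f a -> Maybe (Odds a)
equalOdds xs = case length xs of
  0 -> Nothing
  n -> Just (foldr f undefined xs (n - 1)) where
    f y _ 0 = Certainly y                    -- last element
    f y a k = Odds y (1 % fromIntegral k) (a (k - 1))

-- Two of the three characters are 'a'.
shareOfA :: Maybe Rational
shareOfA = fmap (probOfEach 'a') (equalOdds "aab")
```

Here `equalOdds "aab"` builds `Odds 'a' (1 % 2) (Odds 'a' (1 % 1) (Certainly 'b'))`, and `shareOfA` is `Just (2 % 3)`, which matches counting the occurrences directly.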

And finally, the instances:

append :: Odds a -> Rational -> Odds a -> Odds a
append = foldOdds f Odds where
  f e r a p ys = Odds e ip (a op ys) where
    ip = p * r / (p + r + 1)
    op = p / (r + 1)

flatten :: Odds (Odds a) -> Odds a
flatten = foldOdds append id

instance Applicative Odds where
  pure = Certainly
  fs <*> xs = flatten (fmap (<$> xs) fs)
  
instance Monad Odds where
  x >>= f = flatten (f <$> x)
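
The arithmetic in `append` can be read as follows: `p` is the odds of the whole left-hand list against `ys`, and `r` is the odds of `e` against the rest of the left-hand list. The chance of `e` is then p/(p+1) times r/(r+1), which as odds is pr/(p+r+1); the odds of the rest of the left-hand list against `ys` are p/(r+1). As a sanity check, here is a self-contained sketch that tosses two coins in `do` notation. `coin` and `twoTosses` are illustrative names, and the `Functor` instance is written by hand only so that the block needs no language extensions.

```haskell
data Odds a = Certainly a
            | Odds a Rational (Odds a)

instance Functor Odds where
  fmap g (Certainly x) = Certainly (g x)
  fmap g (Odds x p xs) = Odds (g x) p (fmap g xs)

foldOdds :: (a -> Rational -> b -> b) -> (a -> b) -> Odds a -> b
foldOdds f b = r where
  r (Certainly x) = b x
  r (Odds x p xs) = f x p (r xs)

append :: Odds a -> Rational -> Odds a -> Odds a
append = foldOdds f Odds where
  f e r a p ys = Odds e ip (a op ys) where
    ip = p * r / (p + r + 1)   -- odds of e against everything else
    op = p / (r + 1)           -- odds of the rest of the left list against ys

flatten :: Odds (Odds a) -> Odds a
flatten = foldOdds append id

instance Applicative Odds where
  pure = Certainly
  fs <*> xs = flatten (fmap (<$> xs) fs)

instance Monad Odds where
  x >>= f = flatten (f <$> x)

probOf :: (a -> Bool) -> Odds a -> Rational
probOf p = foldOdds f b where
  b x = if p x then 1 else 0
  f x n r = (if p x then r + n else r) / (n + 1)

coin :: Odds Bool
coin = Odds True 1 (Certainly False)

-- Number of heads (True) in two independent tosses.
twoTosses :: Odds Int
twoTosses = do
  a <- coin
  b <- coin
  pure (fromEnum a + fromEnum b)
```

The outcome `1` appears twice in `twoTosses`, so `probOf` is the right query here: `probOf (== 1) twoTosses` is one half, and `probOf (== 2) twoTosses` is one quarter.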