{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module Spec where
import qualified Control.Monad.State as CMS
import Control.Monad.State (MonadState, get, put)
import Control.Applicative -- not required for the exam
import Test.QuickCheck -- not needed for the exam question
import qualified Test.QuickCheck.Property as QC
import qualified Test.QuickCheck.Arbitrary as QC

------
-- Exam question, supporting code
-- The four MonadState laws (specialised to (S2 s a) to avoid ambiguous types):
putput :: Eq s => s -> s -> s -> Bool
putput s' s = 
  (put s' >> put s)                      =.=  put s

putget :: Eq s => s -> s -> Bool
putget s = 
  (put s >> get)                         =.=  (put s >> return s)

getput :: Eq s => S2 s () -> s -> Bool
getput phony = -- the first argument is only to disambiguate the types
  (get >>= put)                          =.=  (skip `asTypeOf` phony)

getget :: (Eq a, Eq s) => (s -> s -> S2 s a) -> s -> Bool
getget k = 
  (get >>= \s -> get >>= \s' -> k s s')  =.=  (get >>= \s -> k s s)

skip :: Monad m => m ()
skip = return ()

-- Consider the following instance:
data S2 s a where
  Return  ::  a -> S2 s a
  Bind    ::  S2 s a -> (a -> S2 s b) -> S2 s b
  Then    ::  S2 s a -> S2 s b -> S2 s b
  Get     ::  S2 s s
  Put     ::  s -> S2 s ()

instance Monad (S2 s) where {return = Return; (>>=) = Bind; (>>) = Then}
instance MonadState s (S2 s) where {get = Get; put = Put}

{- Task a:  
  Implement a run function |runS2 :: S2 s a -> (s -> (a, s))| and
  prove (by equational reasoning) that the put-put and put-get laws
  hold if |(==)| means ``all runs are equal''.
-}

-- Run function (part of the answer to Task a):

runS2 :: S2 s a   -> (s -> (a, s))
runS2 (Return a)  =  \s -> (a, s)
runS2 (Bind m f)  =  \s -> let (a, s') = runS2 m s  in runS2 (f a) s'
runS2 (Then m m') =  \s -> let (_, s') = runS2 m s  in runS2 m' s'
runS2 Get         =  \s -> (s,  s)
runS2 (Put s)     =  \_ -> ((), s)

{- For the proof part of Task a we apply runS2 to the lhs and rhs in
the same starting state and show equality by equational reasoning.
For testing and type-checking purposes I write the steps of the proof
in a Haskell list - this is not required in the exam.  -}

-- Proof of putput s' s =  (put s' >> put s) =.=  put s
putput_proof :: s -> s -> s -> AllEq ((), s)
putput_proof s1 s2 s3 = AllEq  
  [ runS2 (put s1 >> put s2) s3
  ,  -- def. of put, (>>)
    runS2 (Then (Put s1) (Put s2)) s3
  ,  -- def. of runS2 for Then
    (\s -> let (_, s') = runS2 (Put s1) s  in runS2 (Put s2) s') s3
  ,  -- def. of runS2 for Put (twice)
    (\s -> let (_, s') = ((),s1) in ((),s2)) s3
  ,  -- beta-red. & simplify unreachable binding
    ((), s2)
  ,  -- beta-red. (from here all steps are "backwards")
    (\_ -> ((), s2)) s3
  ,  -- def. of runS2 for Put
    runS2 (Put s2) s3
  ,  -- def. of put
    runS2 (put s2) s3
  ]

-- Proof of put-get:  (put s >> get) ==  (put s >> return s)
putget_proof :: Eq s => s -> s -> AllEq (s, s)
putget_proof s1 s2 = AllEq
  [ runS2 (put s1 >> get) s2
  , -- def. of put, get, (>>)
    runS2 (Then (Put s1) Get) s2
  , -- def. of runS2 for Then
    (\s -> let (_, s') = runS2 (Put s1) s  in runS2 Get s') s2
  , -- def. of runS2 for Put and beta-red.
    let (_, s') = ((), s1)  in runS2 Get s'
  , -- inline
    runS2 Get s1
  , -- def. of runS2 for Get and beta-red.
    (s1, s1)
  , -- def. of runS2 for Return and beta-red.
    runS2 (Return s1) s1
  , -- inline
    let (_, s') = ((), s1)  in runS2 (Return s1) s'
  , -- def. of runS2 for Put and beta-red.
    (\s -> let (_, s') = runS2 (Put s1) s  in runS2 (Return s1) s') s2
  , -- def. of runS2 for Then
    runS2 (Then (Put s1) (Return s1)) s2
  , -- def. put, return, (>>)
    runS2 (put s1 >> return s1) s2
  ]

-- Task (b): Program transformation based on reasoning

data S3 s a where
  Ret3     ::  a -> S3 s a
  GetBind  ::  (s -> S3 s a)  ->  S3 s a
  PutThen  ::  s -> S3 s a    ->  S3 s a

opt :: S2 s a -> S3 s a
opt (Return a)    =  Ret3 a
opt Get           =  get3
opt (Put s)       =  put3 s
opt (Bind m f)    =  removeBind m f
opt (Then m n)    =  removeThen m n

put3 :: s -> S3 s ()
put3 s = PutThen s (Ret3 ())

get3 :: S3 s s
get3 = GetBind Ret3  

-- Task: Implement |removeBind| and |removeThen| and motivate your definitions.
removeBind :: S2 s a -> (a -> S2 s b) -> S3 s b
removeBind (Return a)  f  =  opt (f a)                        -- Monad law 1
removeBind Get         f  =  GetBind (opt . f)                -- new constructor
removeBind (Put s)     f  =  PutThen s (opt (f ()))           -- new constructor
removeBind (Then m n)  f  =  opt (Then m (Bind n f))          -- Monad law 3'
removeBind (Bind m f)  g  =  opt (Bind m (\a-> Bind (f a) g)) -- Monad law 3

removeThen :: S2 s a -> S2 s b -> S3 s b
removeThen (Return a)  n  =  opt n                            -- Monad law 1'
removeThen Get         n  =  opt n                            -- prop. of Get
removeThen (Put s)     n  =  PutThen s (opt n)                -- new constructor
removeThen (Then m n)  o  =  opt (Then m (Then n o))          -- Monad law 3''
removeThen (Bind m f)  n  =  opt (Bind m (\a-> Then (f a) n)) -- Monad law 3'''


{-
-- Alternative right-hand-sides for some of the cases:
removeBind (Then m n)  f  =  removeThen m (Bind n f)
removeBind (Bind m f)  g  =  removeBind m (\a-> Bind (f a) g)

removeThen (Then m n)  o  =  removeThen m (Then n o)
removeThen (Bind m f)  n  =  removeBind m (\a-> Then (f a) n)
-}

removeBindThen ::
  S2 s a1 -> S2 s a2 -> (a2 -> S2 s a) -> AllEq (S3 s a)
removeBindThen m n f = AllEq
  [ removeBind (Then m n) f  
  , -- def. of opt for Bind
    opt (Bind (Then m n) f)
  , -- Monad law 3 (specialised for Then (>>))
    opt (Then m (Bind n f))         
  , -- def. of opt for Then
    removeThen m (Bind n f)
  ]

removeBindBind m f g = AllEq
  [ removeBind (Bind m f) g
  , -- def. opt for Bind
    opt (Bind (Bind m f) g)
  , -- Monad law 3
    opt (Bind m (\a-> Bind (f a) g))
  , -- def. of opt for Bind
    removeBind m (\a-> Bind (f a) g)
  ]

removeThenThen m n o = AllEq
  [ removeThen (Then m n)  o  
  , -- Monad law 3''
    opt (Then m (Then n o))
  , -- def. of opt for Then
    removeThen m (Then n o)
  ]

removeThenBind m f n = AllEq
  [ removeThen (Bind m f)  n  
  , -- Monad law 3'''
    opt (Bind m (\a-> Then (f a) n))
  , -- def. of opt for Bind
    removeBind m (\a-> Then (f a) n)
  ]

-- --------------
-- Below is some supporting code to sanity-check the exam answers.

runS3 :: S3 s a      -> (s -> (a, s))
runS3 (Ret3 a)       =  \s -> (a, s) 
runS3 (GetBind f)    =  \s -> runS3 (f s) s
runS3 (PutThen s m)  =  \_ -> runS3 m s

saneOpt m = \s -> runS2 m s == runS3 (opt m) s

testOpt = quickCheck (saneOpt . compileMState)

-- examples:
inp1 = MThen (MPut True) (MBind (MThen (MPut True) (MReturn False)) (A2MState (MThen (MPut True) (MReturn False)) (MReturn True)))
inp2 = MThen (MPut False) (MReturn True)

-- S3 s is also a state monad
instance Monad (S3 s) where {return = Ret3; (>>=) = bind3}
instance MonadState s (S3 s) where {put = put3; get = get3}

bind3 :: S3 s a -> (a -> S3 s b) -> S3 s b
bind3 (Ret3 a)       f  =  f a
bind3 (GetBind f)    g  =  GetBind (\s-> bind3 (f s) g)
bind3 (PutThen s m)  f  =  PutThen s (bind3 m f)
  

-- Now, how can we use the put-* laws?
opt3 :: S3 s a -> S3 s a  
opt3 (PutThen _ (PutThen s m))  =  PutThen s m  -- put-put
opt3 (PutThen s (GetBind f))    =  PutThen s (opt3 (f s)) -- put-get

saneOpt3 m = \s -> runS3 m' s == runS3 (opt3 m') s
  where m' = opt m

testOpt3 = quickCheck (saneOpt . compileMState)

-- --------------




-- --------------------------------------------------------------
-- Equality check:
(=.=) :: (Eq a, Eq s) => S2 s a -> S2 s a -> s -> Bool
m =.= n = \s -> runS2 m s == runS2 n s

type S = Bool
getget' :: Fun -> S -> Bool
getget' (Fun fun) = getget (\s s' -> compileMState (fun s s'))

putput' :: S -> S -> S -> Bool
putput' = putput

putget' :: S -> S -> Bool
putget' = putget

getput' :: S -> Bool
getput' = getput (undefined :: S2 S ())
  
test1 = do quickCheck getget'
           quickCheck putput'
           quickCheck putget'
           quickCheck getput'

newtype Fun = Fun {unFun :: S -> S -> MState}
  deriving (Arbitrary)

instance Show Fun where
  show = showFun
  
showFun (Fun f) = concatMap show [f x y | x <- [False, True], y <- [False, True]]

instance (Bounded s, Enum s, Show s, Show a) => Show (S2 s a) where
  show = showS2

showS2 sm = concatMap show [g s | s <- [minBound..maxBound]]
  where g = runS2 sm
        



{-
-- A polymorphic instance does not work with the GADT S2
instance (Arbitrary s, Arbitrary a) => Arbitrary (S2 s a) where
  arbitrary = arbitraryS2
  
arbitraryS2 :: (Arbitrary s, Arbitrary a) => Gen (S2 s a)
arbitraryS2 = oneof 
  [ Return <$> arbitrary
  , Bind   <$> arbitrary <*> arbitrary  -- ambiguous
  , Then   <$> arbitrary <*> arbitrary  -- ambiguous
  , pure Get                            -- type mismatch
  , Put    <$> arbitrary                -- type mismatch
  ]
-}

type A = Bool
data MState where -- a result-monomorphic version of S2 s a
  MReturn  ::  A -> MState
  MBind    ::  MState -> (A2MState) -> MState
  MThen    ::  MState' -> MState -> MState
  MGet     ::  MState
  deriving Show

data MState' where
  MPut     ::  S -> MState'
  deriving Show

compileMState' :: MState' -> S2 S ()
compileMState' (MPut s)     =  Put s


data A2MState = A2MState MState MState -- False and True cases
  deriving Show  

compileMState :: MState -> S2 S A
compileMState (MReturn a)  =  Return a
compileMState (MBind m f)  =  Bind (compileMState m) (compileA2MState f)
compileMState (MThen m n)  =  Then (compileMState' m) (compileMState n)
compileMState (MGet)       =  Get
  
compileA2MState :: A2MState -> (A -> S2 S A)
compileA2MState (A2MState f t) a = compileMState $ if a then t else f

instance Arbitrary MState where
  arbitrary = arbitraryMState
  
instance Arbitrary MState' where
  arbitrary = arbitraryMState'

instance Arbitrary A2MState where
  arbitrary = arbitraryA2MState

arbitraryMState' :: Gen MState'
arbitraryMState' = MPut <$> arbitrary

arbitraryMState :: Gen MState
arbitraryMState = oneof 
  [ MReturn <$> arbitrary
  , MBind   <$> arbitrary <*> arbitrary
  , MThen   <$> arbitrary <*> arbitrary
  , pure MGet                            -- type cheat - state S = result type A
  ]

arbitraryA2MState :: Gen A2MState
arbitraryA2MState = A2MState <$> arbitrary <*> arbitrary


----------------------------------------------------------------
-- Sanity checking of the "proofs" - not needed on the exam:

-- Straw-man proofs in Haskell (just a list of supposedly equal expr.)
newtype AllEq a = AllEq [a] 
  deriving (Arbitrary)
{- 
mytest (AllEq ms) = forAll arbitrary $ \e -> 
                    allEq $ map (flip runS2 e) ms

instance (Arbitrary s, Show s, Eq a, Eq s) =>
         Testable (AllEq (S2 s a)) where 
  property = mytest
-}

instance Eq a => Testable (AllEq a) where
  property = propertyAllEq 
                          
propertyAllEq :: Eq a => AllEq a -> Property
propertyAllEq (AllEq xs) = QC.property $ QC.liftBool $ allEq xs
    
allEq []      =  True
allEq (x:xs)  =  all (x==) xs

    

test2 = do quickCheck (\s -> putput_proof (s :: Int))
           quickCheck (\s -> putget_proof (s :: Int))

main = test1 >> test2