{-# LANGUAGE FlexibleContexts, GADTs, TypeFamilies #-}

module Examples.Feldspar where



import Prelude hiding (max, min)
import qualified Prelude

import Data.Typeable

import Lambda
import Frontend



--------------------------------------------------------------------------------
-- * Types
--------------------------------------------------------------------------------

-- | The set of supported types
--
-- The class has no methods of its own; it bundles the 'Eq', 'Show' and
-- 'Typeable' superclasses so that a single @Type a@ constraint gives access to
-- all three.
class (Eq a, Show a, Typeable a) => Type a

-- Primitive types
instance Type Bool
instance Type Int
instance Type Float
-- Compound types, supported whenever their components are
instance Type a => Type [a]  -- Type of arrays
instance (Type a, Type b) => Type (a,b)

-- | Size argument of the 'Parallel' and 'ForLoop' constructs
type Length = Int
-- | Position argument handed to the index function of 'Parallel' and to the
-- body of 'ForLoop'
type Index  = Int



--------------------------------------------------------------------------------
-- * General tuple expressions
--------------------------------------------------------------------------------

-- | Extends an expression type @expr@ with pair construction and projection.
-- The type index @a@ is the type of the value denoted by each symbol.
data Tuple expr a
  where
    -- | First projection of a pair
    Fst     :: Tuple expr ((a,b) -> a)
    -- | Second projection of a pair
    Snd     :: Tuple expr ((a,b) -> b)
    -- | Pair constructor
    Pair    :: Tuple expr (a -> b -> (a,b))
    -- | Embedding of an expression from the underlying type @expr@
    InjPair :: expr a -> Tuple expr a

-- | Tuple symbols evaluate to the corresponding Haskell pair functions;
-- embedded expressions are evaluated by the underlying 'Eval' instance.
instance Eval expr => Eval (Tuple expr)
  where
    eval sym = case sym of
        Fst       -> fst
        Snd       -> snd
        Pair      -> (,)
        InjPair e -> eval e

-- | Two tuple symbols are equal when they are the same constructor; embedded
-- expressions are compared by the underlying 'ExprEq' instance.
instance ExprEq expr => ExprEq (Tuple expr)
  where
    exprEq lhs rhs = case (lhs, rhs) of
        (Fst,       Fst)       -> True
        (Snd,       Snd)       -> True
        (Pair,      Pair)      -> True
        (InjPair x, InjPair y) -> exprEq x y
        _                      -> False

-- | Tuple symbols are rendered by a fixed name; embedded expressions are
-- rendered by the underlying 'ExprShow' instance.
instance ExprShow expr => ExprShow (Tuple expr)
  where
    exprShow sym = case sym of
        Fst       -> "fst"
        Snd       -> "snd"
        Pair      -> "pair"
        InjPair e -> exprShow e



--------------------------------------------------------------------------------
-- * The Feldspar domain
--------------------------------------------------------------------------------

-- | The primitive symbols of the Feldspar domain. The type index @a@ is the
-- type of the value denoted by each symbol.
data Feldspar a
  where
    -- | A constant of a supported type
    Literal :: Type a => a -> Feldspar a

    -- | A named Haskell function. The 'String' is the name used when showing
    -- and comparing the symbol; the function gives its meaning.
    Function :: String -> (a -> b) -> Feldspar (a -> b)

    -- | Conditional: condition, then-branch, else-branch
    Condition :: Feldspar (Bool -> a -> a -> a)

    -- | Array construction from a length and a function from index to element
    Parallel :: Feldspar (Length -> (Index -> a) -> [a])

    -- | Loop: number of iterations, initial state, and a body taking the index
    -- and the current state to the next state
    ForLoop :: Feldspar (Length -> st -> (Index -> st -> st) -> st)

-- | User-facing expression type: a 'Lam' term over the Feldspar symbols
-- extended with tuples.
--
-- Declared as a @newtype@ because it is a single-constructor, single-field
-- wrapper; this avoids an extra box around every expression. The constructor
-- 'Data' and the accessor 'unData' are unchanged.
newtype Data a = Data { unData :: Lam (Tuple Feldspar) a }

-- | Reference semantics of the Feldspar primitives.
--
-- * 'Literal' denotes its value and 'Function' denotes the wrapped function.
--
-- * 'Condition' selects between its two branches.
--
-- * 'Parallel' builds the list of elements for indices @0 .. len-1@.
--
-- * 'ForLoop' runs the body for indices @0 .. len-1@ in ascending order,
--   threading the state from one iteration to the next.
instance Eval Feldspar
  where
    eval (Literal a)    = a
    eval (Function _ f) = f
    eval Condition      = \cond tHEN eLSE -> if cond then tHEN else eLSE
    eval Parallel       = \len ixf -> Prelude.map ixf [0 .. len-1]
    -- A left fold is required here: the body for index 0 must see the initial
    -- state and the body for index @len-1@ must produce the final state. A
    -- right fold would visit the indices in the opposite order. The lazy
    -- 'Prelude.foldl' keeps the state as non-strict as in the other constructs.
    eval ForLoop        = \len st0 body ->
        Prelude.foldl (\st i -> body i st) st0 [0 .. len-1]

-- | Structural equality of Feldspar symbols.
--
-- Literals are equal when the first can be 'cast' to the type of the second
-- and the values are equal. Functions are compared by name only; the wrapped
-- Haskell functions are not inspected.
instance ExprEq Feldspar
  where
    exprEq lhs rhs = case (lhs, rhs) of
        (Literal a,     Literal b)     -> cast a == Just b
        (Function n1 _, Function n2 _) -> n1 == n2
        (Condition,     Condition)     -> True
        (Parallel,      Parallel)      -> True
        (ForLoop,       ForLoop)       -> True
        _                              -> False

-- | Literals are rendered with 'show', functions by their name, and the
-- remaining constructs by a fixed keyword.
instance ExprShow Feldspar
  where
    exprShow sym = case sym of
        Literal a       -> show a
        Function name _ -> name
        Condition       -> "condition"
        Parallel        -> "parallel"
        ForLoop         -> "forLoop"



-- | Wrap a Feldspar symbol with 'InjPair' and pass it to 'inject' to obtain a
-- 'Lam' term.
injFeldspar :: Feldspar a -> Lam (Tuple Feldspar) a
injFeldspar sym = inject (InjPair sym)

-- | A literal expression holding the given value
lit :: Type a => a -> Data a
lit val = Data (injFeldspar (Literal val))

-- | Lift a named unary Haskell function to expressions. The name is what
-- 'exprShow' and 'exprEq' see; the function itself gives the evaluation.
function :: Type a => String -> (a -> b) -> Data a -> Data b
function name f arg = Data (sym $$ unData arg)
  where
    sym = injFeldspar (Function name f)

-- | Lift a named binary Haskell function to expressions; the two-argument
-- counterpart of 'function'.
function2 :: (Type a, Type b) =>
    String -> (a -> b -> c) -> Data a -> Data b -> Data c
function2 name f arg1 arg2 = Data (sym $$ unData arg1 $$ unData arg2)
  where
    sym = injFeldspar (Function name f)

-- | Evaluation of a 'Data' expression is evaluation of the wrapped term.
instance Eval Data
  where
    eval d = eval (unData d)

-- | Equality of 'Data' expressions is equality of the wrapped terms.
instance Eq (Data a)
  where
    (==) x y = unData x == unData y

-- | A 'Data' expression is shown as its wrapped term.
instance Show (Data a)
  where
    show d = show (unData d)

-- | Numeric operations on expressions. Integer literals become 'lit' nodes;
-- every other method becomes a named function node built with 'function' or
-- 'function2'.
instance (Type a, Num a) => Num (Data a)
  where
    fromInteger n = lit (fromInteger n)
    abs x         = function "abs" abs x
    signum x      = function "signum" signum x
    x + y         = function2 "(+)" (+) x y
    x - y         = function2 "(-)" (-) x y
    x * y         = function2 "(*)" (*) x y

-- | First component of a pair expression
fST :: (Type a, Type b) => Data (a,b) -> Data a
fST p = Data (inject Fst $$ unData p)

-- | Second component of a pair expression
sND :: (Type a, Type b) => Data (a,b) -> Data b
sND p = Data (inject Snd $$ unData p)

-- | Build a pair expression from two component expressions
pair :: (Type a, Type b) => Data a -> Data b -> Data (a,b)
pair l r = Data (inject Pair $$ unData l $$ unData r)

-- | Conditional on 'Data' expressions: the condition, then the two branches.
-- See 'condition' for the version generalised to 'Computable' types.
condition' :: Type a => Data Bool -> Data a -> Data a -> Data a
condition' c t e =
    Data (injFeldspar Condition $$ unData c $$ unData t $$ unData e)

-- | Array expression given by a length and a function from index expression
-- to element expression. The index function is turned into a term with
-- 'lambda'.
parallel :: Type a => Data Length -> (Data Index -> Data a) -> Data [a]
parallel len ixf = Data (injFeldspar Parallel $$ unData len $$ ixTerm)
  where
    ixTerm = lambda (\i -> unData (ixf (Data i)))

-- | Loop on 'Data' expressions: iteration count, initial state, and a body
-- from index and current state to next state. The two-argument body is
-- turned into a term with two nested uses of 'lambda'. See 'forLoop' for the
-- version generalised to 'Computable' state.
forLoop' :: Typeable st =>
     Data Length -> Data st -> (Data Index -> Data st -> Data st) -> Data st
forLoop' len st0 step =
    Data $ injFeldspar ForLoop $$ unData len $$ unData st0 $$
        lambda (\i -> lambda (\st -> unData (step (Data i) (Data st))))

-- | Haskell-level structures that can be converted to and from a single
-- 'Data' expression.
class Type (Internal a) => Computable a
  where
    -- | The type of the value denoted by the 'Data' expression that
    -- represents @a@
    type Internal a
    -- | Convert a structure to its 'Data' representation
    internalize :: a -> Data (Internal a)
    -- | Convert a 'Data' representation back to the structure
    externalize :: Data (Internal a) -> a

-- | A 'Data' expression is its own representation; both conversions are the
-- identity.
instance Type a => Computable (Data a)
  where
    type Internal (Data a) = a
    internalize d = d
    externalize d = d

-- | A Haskell pair of computable structures is represented by a pair
-- expression: built with 'pair', taken apart with 'fST' and 'sND'.
instance (Computable a, Computable b) => Computable (a,b)
  where
    type Internal (a,b) = (Internal a, Internal b)
    internalize (x, y)  = pair (internalize x) (internalize y)
    externalize p       =
        ( externalize (fST p)
        , externalize (sND p)
        )

-- | Round-trip a structure through its 'Data' representation
force :: Computable a => a -> a
force x = externalize (internalize x)

-- | Evaluate a computable structure to the Haskell value it denotes
evalComputable :: Computable a => a -> Internal a
evalComputable x = eval (internalize x)

-- | Internalize a structure, pass the underlying term through 'runLambda',
-- and 'print' the result
printComputable :: Computable a => a -> IO ()
printComputable x = print (runLambda (unData (internalize x)))

-- | Internalize a structure, pass the underlying term through 'runLambda',
-- and hand the result to 'drawLambda'
drawComputable :: Computable a => a -> IO ()
drawComputable x = drawLambda (runLambda (unData (internalize x)))

-- | Conditional over any 'Computable' structure: both branches are
-- internalized, combined with @condition'@, and the result is externalized.
condition :: Computable a => Data Bool -> a -> a -> a
condition c t e =
    externalize (condition' c (internalize t) (internalize e))

-- | Loop over any 'Computable' state: the state is internalized before each
-- step of @forLoop'@ and externalized again for the user-supplied body.
forLoop :: Computable st => Data Length -> st -> (Data Index -> st -> st) -> st
forLoop len st0 step =
    externalize $ forLoop' len (internalize st0) $ \i st ->
        internalize (step i (externalize st))

-- | Length of an array expression, as a function node named @"arrLength"@
arrLength :: Type a => Data [a] -> Data Length
arrLength arr = function "arrLength" Prelude.length arr

-- | Element of an array expression at the given index, as a function node
-- named @"getIx"@. Evaluation calls 'error' when the index is negative or not
-- smaller than the array length.
getIx :: Type a => Data [a] -> Data Index -> Data a
getIx = function2 "getIx" lookupIx
  where
    lookupIx xs i =
        if i >= Prelude.length xs || i < 0
          then error "getIx: index out of bounds"
          else xs !! i

-- | Larger of two expressions, as a function node named @"max"@
max :: (Type a, Ord a) => Data a -> Data a -> Data a
max x y = function2 "max" Prelude.max x y

-- | Smaller of two expressions, as a function node named @"min"@
min :: (Type a, Ord a) => Data a -> Data a -> Data a
min x y = function2 "min" Prelude.min x y