{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}

-- | A small Feldspar-style embedded language built on the "Lambda" and
-- "Frontend" modules.
--
-- It provides a domain of primitive operations ('Feldspar'), general tuple
-- expressions ('Tuple'), the user-facing 'Data' wrapper, and the
-- 'Computable' class for converting between Haskell tuples of 'Data' and a
-- single 'Data' value.
module Examples.Feldspar where

import Prelude hiding (max, min)
import qualified Prelude

import Data.List (foldl')
import Data.Typeable

import Lambda
import Frontend

--------------------------------------------------------------------------------
-- * Types
--------------------------------------------------------------------------------

-- | The set of supported types.
class (Eq a, Show a, Typeable a) => Type a

instance Type Bool
instance Type Int
instance Type Float
instance Type a => Type [a]  -- Type of arrays
instance (Type a, Type b) => Type (a,b)

-- | Array lengths.
type Length = Int

-- | Array and loop indices.
type Index = Int

--------------------------------------------------------------------------------
-- * General tuple expressions
--------------------------------------------------------------------------------

-- | Tuple construction and projection, plus an injection of an arbitrary
-- expression type @expr@.
data Tuple expr a where
  -- | First projection.
  Fst     :: Tuple expr ((a,b) -> a)
  -- | Second projection.
  Snd     :: Tuple expr ((a,b) -> b)
  -- | Pair construction.
  Pair    :: Tuple expr (a -> b -> (a,b))
  -- | Embeds an @expr@ node.
  InjPair :: expr a -> Tuple expr a

instance Eval expr => Eval (Tuple expr) where
  eval Fst         = fst
  eval Snd         = snd
  eval Pair        = (,)
  eval (InjPair a) = eval a

instance ExprEq expr => ExprEq (Tuple expr) where
  Fst       `exprEq` Fst       = True
  Snd       `exprEq` Snd       = True
  Pair      `exprEq` Pair      = True
  InjPair a `exprEq` InjPair b = a `exprEq` b
  _         `exprEq` _         = False

instance ExprShow expr => ExprShow (Tuple expr) where
  exprShow Fst         = "fst"
  exprShow Snd         = "snd"
  exprShow Pair        = "pair"
  exprShow (InjPair a) = exprShow a

--------------------------------------------------------------------------------
-- * The Feldspar domain
--------------------------------------------------------------------------------

-- | Primitive operations of the language.
data Feldspar a where
  -- | A constant of a supported type.
  Literal   :: Type a => a -> Feldspar a
  -- | A named Haskell function; the name is used for showing and equality.
  Function  :: String -> (a -> b) -> Feldspar (a -> b)
  -- | @condition c t e@: @t@ if @c@, otherwise @e@.
  Condition :: Feldspar (Bool -> a -> a -> a)
  -- | @parallel len ixf@: the array @[ixf 0, ..., ixf (len-1)]@.
  Parallel  :: Feldspar (Length -> (Index -> a) -> [a])
  -- | @forLoop len st0 body@: threads a state through @body@ for indices
  -- @0 .. len-1@ in ascending order.
  ForLoop   :: Feldspar (Length -> st -> (Index -> st -> st) -> st)

-- | User-facing expressions: a 'Lam' term over 'Tuple' 'Feldspar'.
data Data a = Data { unData :: Lam (Tuple Feldspar) a }

instance Eval Feldspar where
  eval (Literal a)    = a
  eval (Function _ f) = f
  eval Condition      = \cond thenBranch elseBranch ->
    if cond then thenBranch else elseBranch
  eval Parallel       = \len ixf -> Prelude.map ixf [0 .. len-1]
  -- A left fold visits index 0 first and index len-1 last. The previous
  -- 'foldr' applied the body in the reverse order.
  eval ForLoop        = \len initial body ->
    foldl' (flip body) initial [0 .. len-1]

instance ExprEq Feldspar where
  -- Literals are equal only if they have the same type and value.
  Literal a     `exprEq` Literal b     = maybe False (== b) (cast a)
  -- NOTE(review): only the names are compared, so two functions that share
  -- a name but have different types (e.g. "abs" on Int and on Float) are
  -- considered equal. Confirm this is acceptable for callers of 'exprEq'.
  Function f1 _ `exprEq` Function f2 _ = f1 == f2
  Condition     `exprEq` Condition     = True
  Parallel      `exprEq` Parallel      = True
  ForLoop       `exprEq` ForLoop       = True
  _             `exprEq` _             = False

instance ExprShow Feldspar where
  exprShow (Literal a)       = show a
  exprShow (Function name _) = name
  exprShow Condition         = "condition"
  exprShow Parallel          = "parallel"
  exprShow ForLoop           = "forLoop"

-- | Embeds a 'Feldspar' node as a 'Lam' term (via 'InjPair' and 'inject').
injFeldspar :: Feldspar a -> Lam (Tuple Feldspar) a
injFeldspar = inject . InjPair

-- | A literal constant.
lit :: Type a => a -> Data a
lit = Data . injFeldspar . Literal

-- | Lifts a named unary Haskell function to 'Data'.
function :: Type a => String -> (a -> b) -> Data a -> Data b
function name f a = Data $ injFeldspar (Function name f) $$ unData a

-- | Lifts a named binary Haskell function to 'Data'.
function2 :: (Type a, Type b)
          => String -> (a -> b -> c) -> Data a -> Data b -> Data c
function2 name f a b =
  Data $ injFeldspar (Function name f) $$ unData a $$ unData b

instance Eval Data where
  eval = eval . unData

instance Eq (Data a) where
  x == y = unData x == unData y

instance Show (Data a) where
  show = show . unData

-- | Arithmetic on 'Data': every operation becomes a named 'Function' node.
-- 'negate' uses the class default (@0 - x@).
instance (Type a, Num a) => Num (Data a) where
  fromInteger = lit . fromInteger
  abs         = function "abs" abs
  signum      = function "signum" signum
  (+)         = function2 "(+)" (+)
  (-)         = function2 "(-)" (-)
  (*)         = function2 "(*)" (*)

-- | First component of a pair.
fST :: (Type a, Type b) => Data (a,b) -> Data a
fST ab = Data $ inject Fst $$ unData ab

-- | Second component of a pair.
sND :: (Type a, Type b) => Data (a,b) -> Data b
sND ab = Data $ inject Snd $$ unData ab

-- | Builds a pair.
pair :: (Type a, Type b) => Data a -> Data b -> Data (a,b)
pair a b = Data $ inject Pair $$ unData a $$ unData b

-- | Conditional on 'Data' values; see 'condition' for 'Computable' values.
condition' :: Type a => Data Bool -> Data a -> Data a -> Data a
condition' cond thenBranch elseBranch =
  Data $ injFeldspar Condition
    $$ unData cond
    $$ unData thenBranch
    $$ unData elseBranch

-- | @parallel len ixf@ builds an array of length @len@ whose element @i@ is
-- @ixf i@.
parallel :: Type a => Data Length -> (Data Index -> Data a) -> Data [a]
parallel len ixf =
  Data $ injFeldspar Parallel $$ unData len $$ lambda (unData . ixf . Data)

-- | For loop on a 'Data' state. Indices run from @0@ to @len-1@ in
-- ascending order. See 'forLoop' for 'Computable' states.
forLoop' :: Typeable st
         => Data Length
         -> Data st
         -> (Data Index -> Data st -> Data st)
         -> Data st
forLoop' len initial body =
  Data $ injFeldspar ForLoop $$ unData len $$ unData initial $$ body'
  where
    -- The two-argument body is encoded as nested 'lambda' abstractions.
    body' = lambda $ \i -> lambda $ \st -> unData $ body (Data i) (Data st)

-- | Values that can be converted to and from a single 'Data' expression.
-- 'Internal' gives the corresponding Haskell value type.
class Type (Internal a) => Computable a where
  type Internal a
  internalize :: a -> Data (Internal a)
  externalize :: Data (Internal a) -> a

instance Type a => Computable (Data a) where
  type Internal (Data a) = a
  internalize = id
  externalize = id

-- | Pairs are internalized with 'pair' and externalized with 'fST' / 'sND'.
instance (Computable a, Computable b) => Computable (a,b) where
  type Internal (a,b) = (Internal a, Internal b)
  internalize (a,b) = internalize a `pair` internalize b
  externalize ab    = (externalize (fST ab), externalize (sND ab))

-- | Round-trips a value through its 'Data' representation.
force :: Computable a => a -> a
force = externalize . internalize

-- | Evaluates a 'Computable' value to its 'Internal' Haskell value.
evalComputable :: Computable a => a -> Internal a
evalComputable = eval . internalize

-- | Prints the result of 'runLambda' on the internalized expression.
printComputable :: Computable a => a -> IO ()
printComputable = print . runLambda . unData . internalize

-- | Passes the result of 'runLambda' on the internalized expression to
-- 'drawLambda'.
drawComputable :: Computable a => a -> IO ()
drawComputable = drawLambda . runLambda . unData . internalize

-- | Conditional on any 'Computable' value.
condition :: Computable a => Data Bool -> a -> a -> a
condition cond thenBranch elseBranch =
  externalize $ condition' cond (internalize thenBranch) (internalize elseBranch)

-- | For loop with a 'Computable' state. Indices run from @0@ to @len-1@ in
-- ascending order.
forLoop :: Computable st => Data Length -> st -> (Data Index -> st -> st) -> st
forLoop len initial body = externalize $ forLoop' len (internalize initial) body'
  where
    body' i st = internalize $ body i (externalize st)

-- | Length of an array.
arrLength :: Type a => Data [a] -> Data Length
arrLength = function "arrLength" Prelude.length

-- | Array indexing.
--
-- Evaluation calls 'error' when the index is negative or not less than the
-- array length.
getIx :: Type a => Data [a] -> Data Index -> Data a
getIx arr ix = function2 "getIx" safeIndex arr ix
  where
    -- Checking @i < 0@ first avoids computing the length for negative indices.
    safeIndex xs i
      | i < 0 || i >= Prelude.length xs = error "getIx: index out of bounds"
      | otherwise                       = xs !! i

-- | Maximum of two values.
max :: (Type a, Ord a) => Data a -> Data a -> Data a
max = function2 "max" Prelude.max

-- | Minimum of two values.
min :: (Type a, Ord a) => Data a -> Data a -> Data a
min = function2 "min" Prelude.min