module ObsLab where

import Obsidian
import Obsidian.CodeGen.CUDA
import Obsidian.Run.CUDA.Exec 

import Prelude hiding (zipWith)
import qualified Prelude as P
import Control.Monad.State

import qualified Data.Vector.Storable as V 
import Data.Word 


----------------------------------------------------------------------
-- TASK 3 Vector Addition 
----------------------------------------------------------------------

------------------------------------------------------------
-- vadd block-local computation
-- - Look at the "inc" example from the lecture slides 
------------------------------------------------------------

-- | Block-local element-wise vector addition.
--
-- Adds the two input arrays element by element with Obsidian's pull-array
-- 'zipWith' (Prelude's 'zipWith' is hidden by this module's imports).
vaddLocal :: SPull EFloat
          -> SPull EFloat
          -> SPull EFloat
vaddLocal = zipWith (+)

------------------------------------------------------------
-- Distribute and replicate the vadd computation over blocks
-- - Split the array into parts of 128 elements each
-- - Look at the "inc" example from the lecture slides 
------------------------------------------------------------

-- | Grid-level vector addition.
--
-- Both inputs are cut into chunks of 128 elements with 'splitUp'.
-- Corresponding chunks are added by 'vaddLocal', and each sum is pushed.
-- 'asGrid' then distributes the resulting per-chunk push arrays over the
-- blocks of the grid. 'launchVadd' runs 2 blocks over 256 elements,
-- which matches this chunk size.
vadd :: DPull EFloat -> DPull EFloat -> DPush Grid EFloat
vadd a b =
  asGrid $ fmap push $ zipWith vaddLocal (splitUp 128 a) (splitUp 128 b)
 

------------------------------------------------------------
-- Launch the vadd computation on the GPU
------------------------------------------------------------
    
-- | Compile 'vadd' to a GPU kernel and run it on 2 blocks. The inputs are
-- [0..255] and its reverse; the 256-element result is printed.
launchVadd :: IO () 
launchVadd =
  withCUDA $
  do
    -- capture: compiles the Obsidian code all the
    --          way to GPU executable format.
    --  - The parameter "64" indicates the number
    --    of threads to generate code for.
    --    In this case 64 "real" CUDA threads are used
    --    to perform the 128 additions of each block in vadd.
    --    Experiment with this number. 
    kern <- capture 64 vadd

    -- Generate input data and allocate arrays in GPU DRAM.
    -- This is an example. For timing, larger arrays
    -- and more blocks should be used. 
    useVector (V.fromList [0..255]) $ \a ->
      useVector (V.fromList (P.reverse [0..255])) $ \b ->
        withVector 256 $ \o ->
        do
          fill o 0

          -- Launch 2 Blocks computing "kern" with
          -- arrays a and b as input.
          -- Put output in o. 
          o <== (2,kern) <> a <> b

          r <- copyOut o
          -- 'print' is 'putStrLn . show'. 'lift' runs the IO action inside
          -- withCUDA's monad.
          lift $ print r


----------------------------------------------------------------------
-- TASK 4 (Reduction) 
----------------------------------------------------------------------

------------------------------------------------------------
-- Reduction (sum or generalized) 
------------------------------------------------------------

-- | Block-local sum reduction (exercise placeholder).
--
-- The guards split the work into a one-element case and everything else
-- (presumably the base case and the recursive reduction step). Both
-- branches are still 'undefined', so forcing the result fails at runtime.
sumLocal :: SPull EFloat
         -> Program Block (SPush Block EFloat) 
sumLocal arr
  | len arr == 1 = undefined
  | otherwise = undefined 

-- alternative
       
-- | Alternative block-local sum (exercise placeholder).
--
-- Unlike 'sumLocal', this returns the push array directly by handing the
-- local computation 'body' to 'execBlock'. Both guard branches of 'body'
-- are still 'undefined'.
sumLocal' :: SPull EFloat
          -> SPush Block EFloat 
sumLocal' arr = execBlock $ body arr
  where
    -- The parameter is named 'xs' rather than 'arr' so it does not shadow
    -- the outer argument (which -Wname-shadowing would report).
    body xs
      | len xs == 1 = undefined
      | otherwise   = undefined


------------------------------------------------------------
-- Perform many parallel reductions. One per block 
------------------------------------------------------------
-- | Grid-level sums, intended to produce one reduction result per block
-- (exercise placeholder: the argument to 'asGrid' is still 'undefined').
-- NOTE(review): 'launchSums' and 'launchAlt2' allocate 2 output elements
-- for 2 blocks, so one result per block is assumed.
sums :: DPull EFloat -> DPush Grid EFloat
sums arr = asGrid $ undefined

-- alternative 

-- | Alternative grid-level sums (exercise placeholder; the argument to
-- 'asGrid' is still 'undefined'). Presumably meant to be built on
-- 'sumLocal'' instead of 'sumLocal'. Verify against the lab description.
sums' :: DPull EFloat -> DPush Grid EFloat
sums' arr = asGrid $ undefined 


------------------------------------------------------------
-- Launch the sums computation on the GPU
------------------------------------------------------------

-- | Compile 'sums' to a GPU kernel, run it on 2 blocks over the input
-- [0..255], and print the 2-element output.
launchSums :: IO () 
launchSums =
  withCUDA $
  do
    kern <- capture 64 sums 

    -- generate input data and allocate arrays in GPU DRAM
    useVector (V.fromList [0..255]) $ \a ->
        withVector 2 $ \o ->
        do
          fill o 0

          -- Launch 2 Blocks computing "kern" with
          -- array a as input.
          -- Put output in o (one element per block).
          o <== (2,kern) <> a 

          r <- copyOut o
          lift $ print r

    
----------------------------------------------------------------------
-- TASK 5 (Dot Product) 
----------------------------------------------------------------------
  
  -- This task combines a reduction and an element-wise operation.
  -- The reduction needed is a sum, and the element-wise operation
  -- is the pointwise product of the two vectors.

-- | Block-local element-wise product of two arrays.
--
-- Multiplies the inputs element by element with Obsidian's pull-array
-- 'zipWith' (Prelude's 'zipWith' is hidden by this module's imports).
prodLocal :: SPull EFloat
          -> SPull EFloat
          -> SPull EFloat
prodLocal = zipWith (*)

-- | Grid-level element-wise products.
--
-- Both inputs are cut into 128-element chunks with 'splitUp'.
-- Corresponding chunks are multiplied by 'prodLocal', and each product is
-- pushed. 'asGrid' then distributes the per-chunk push arrays over the
-- blocks of the grid. This uses the same structure as 'vadd'.
products :: DPull EFloat -> DPull EFloat -> DPush Grid EFloat
products a b =
  asGrid $ fmap push $ zipWith prodLocal (splitUp 128 a) (splitUp 128 b)

------------------------------------------------------------
-- perform many dot products in parallel, one per block. 
------------------------------------------------------------

-- Alternative (1): create a single dotProds kernel.
-- It should combine prodLocal and sumLocal.
-- | Many dot products in parallel, one per block (exercise placeholder).
--
-- Both 'dotProds' and its per-block 'body' are still 'undefined'. The hint
-- below suggests writing 'body' with 'execBlock', combining 'prodLocal'
-- and 'sumLocal'.
dotProds :: DPull EFloat -> DPull EFloat -> DPush Grid EFloat
dotProds a1 a2 = asGrid $ undefined
  where
    body :: SPull EFloat -> SPull EFloat -> SPush Block EFloat
    body a b = undefined 
    -- (hint) -- body a b = execBlock $ do ... 
     
-- Another alternative (2) is to run the "products" and "sums" kernels
-- one after another as separate kernel launches. 

------------------------------------------------------------
-- Launch Alternative 2
------------------------------------------------------------

-- | Dot products via two separate kernel launches. The first launch
-- ('products') writes element-wise products into a temporary 256-element
-- device array. The second launch ('sums') reduces that array into a
-- 2-element output, which is then printed.
launchAlt2 :: IO () 
launchAlt2 =
  withCUDA $ do
    prodKern <- capture 64 products
    sumsKern <- capture 64 sums

    -- Inputs are 0..255 and the same values reversed, allocated in GPU DRAM.
    useVector (V.fromList [0..255]) $ \xs ->
      useVector (V.fromList (P.reverse [0..255])) $ \ys ->
        withVector 256 $ \prods ->
          withVector 2 $ \out -> do
            fill out 0
            -- Stage 1: element-wise products into the temporary array.
            prods <== (2, prodKern) <> xs <> ys
            -- Stage 2: reduce the temporary array into the output.
            out   <== (2, sumsKern) <> prods
            copyOut out >>= lift . print

------------------------------------------------------------
-- Launch Alternative 1
------------------------------------------------------------

-- | Dot products in a single kernel launch of 'dotProds' on 2 blocks,
-- printing the 2-element output.
launchAlt1 :: IO () 
launchAlt1 =
  withCUDA $ do
    dotKern <- capture 64 dotProds

    -- Inputs are 0..255 and the same values reversed, allocated in GPU DRAM.
    useVector (V.fromList [0..255]) $ \xs ->
      useVector (V.fromList (P.reverse [0..255])) $ \ys ->
        withVector 2 $ \out -> do
          fill out 0
          out <== (2, dotKern) <> xs <> ys
          copyOut out >>= lift . print



----------------------------------------------------------------------
-- Task 6 
----------------------------------------------------------------------

 -- This task is very open.
 -- * Build on one of the previous tasks by parameterizing
 --   and generalizing.
 -- * Implement anything. There are ideas in the LAB description. 
  


----------------------------------------------------------------------
-- Main (example) 
----------------------------------------------------------------------
          
-- | Example entry point: runs the vector-addition demo.
main :: IO ()
main = launchVadd