-- Lambda Lifting for GHC Core
-- The algorithm is due to Danvy and Schultz and described
-- in "Lambda-Lifting in Quadratic Time" - BRICS RS-02-30.
-- The main function applies 3 transformations in turn:
-- 1. Naming anonymous functions
-- 2. Parameter lifting: closing all function definitions
-- 3. Block floating: moving the function definitions to the top level
module LambdaLifting (lambdaLift) where
import Maybe
import List(sortBy)
import Monad
import Control.Monad.State
import Data.Graph
import Data.Tree
import Data.Set
import ExternalCore
--import Observe
lambdaLift :: Module -> Module
lambdaLift m = bfModule (plModule (nameModule m))
-- Env definitions
varOfBind :: Bind -> Var
varOfBind (Vb (v, ty)) = v
varOfBind (Tb (v, k)) = v
instance Eq Bind where
b1 == b2 = (varOfBind b1) == (varOfBind b2)
instance Ord Bind where
compare b1 b2 = compare (varOfBind b1) (varOfBind b2)
orderBind (Tb _) (Vb _) = GT
orderBind (Vb _) (Tb _) = LT
orderBind _ _ = LT
bindsLk :: [(Var, Set Var)] -> Var -> Set Var
bindsLk l x = fromMaybe emptySet (lookup x l)
data Env = Env { env_binds :: [(Var, Set Var)],
env_vtypes :: [(Var, Bind)] }
empty_env = Env { env_binds = [], env_vtypes = [] }
envLk :: Env -> Var -> Set Var
envLk e x = bindsLk (env_binds e) x
envConcatBinds :: Env -> [(Var, Set Var)] -> Env
envConcatBinds s l = s { env_binds = (env_binds s) ++ l }
envAddBind :: Env -> Bind -> Env
envAddBind s b = s { env_vtypes = (varOfBind b,b):(env_vtypes s) }
envTypeOfBind s (Vb (v,t)) =
case (fromMaybe (Vb (v,t)) (lookup v (env_vtypes s))) of
Vb (v,t) -> t
Tb (v,k) -> error "typeOfBind"
envTypeOfBind s (Tb (v,k)) = error "typeOfBind"
envBindOfVar :: Env -> Var -> Bind
envBindOfVar s v = case lookup v (env_vtypes s) of
Just b -> b
Nothing -> error "envBindOfVar"
-- Nothing -> Vb (v, tWild)
-- Parameter lifting
plModule :: Module -> Module
plModule (Module name tdefs vdefgs) = Module name tdefs vdefgs'
where vdefgs' = foldr filterVdefg [] (map (plVdefg empty_env) vdefgs)
filterVdefg (Rec []) l = l
filterVdefg v l = v:l
plVdefg :: Env -> Vdefg -> Vdefg
plVdefg s (Rec vdefs) = Rec (foldr plAndFilterVdef [] vdefs)
where plAndFilterVdef ((_, 'z':'d':'g':s),_,_) l = l
plAndFilterVdef vdef l = (plVdef s vdef):l
plVdefg s (Nonrec ((_, 'z':'d':'g':st),_,_)) = Rec []
plVdefg s (Nonrec vdef) = Nonrec (plVdef s vdef)
plVdef :: Env -> Vdef -> Vdef
plVdef s (qvar, ty, exp) = applySolutionToVdef s (qvar, ty, plExp s exp)
plExp :: Env -> Exp -> Exp
plExp s e = case e of
Var qvar -> applySolutionToExp s (Var qvar)
Dcon qdcon -> Dcon qdcon
Lit l -> Lit l
App e1 e2 -> App e1' e2'
where e1' = plExp s e1
e2' = plExp s e2
Appt e ty -> Appt e' ty where e' = plExp s e
Lam bind e -> Lam bind (plExp (envAddBind s bind) e)
Let vdefg e -> Let vdefg' e'
where (vv, vf) = splitVars vdefg
funs = mkSet (map fst vf)
vfs = fvMapVdefg vf vdefg
g = foldl add_edges [] vf
where add_edges l (f,b) =
let p = setToList
(intersect funs (bindsLk vfs f))
edges = foldl (\ l -> \ h -> (f,f,[h]):l) [] p
in edges ++ l
g' = stronglyConnCompR g
succAssoc :: [SCC (Var, Var, [Var])] -> Var -> Set Var
succAssoc l f = case l of
[] -> emptySet
(AcyclicSCC (g,h,l)):r | f == g -> mkSet l
(CyclicSCC l):r -> look l
where look :: [(Var,Var,[Var])] -> Set Var
look [] = succAssoc r f
look ((g,h,l):r) | g == f = mkSet l
look (h:r) = look r
h:r -> succAssoc r f
propagate :: [[Var]] -> [(Var, Set Var)]
propagate l = case l of
[] -> []
l:r -> res ++ (propagate r)
where res = map (\f -> (f, minusSet vars (fvars f))) l
vars = union u1 u2
u1 = unionManySets (map (bindsLk vfs) l)
u2 = unionManySets (map (succAssoc g') l)
fvars = \f -> mkSet (fromMaybe [] (lookup f vf))
redG :: [SCC (Var, Var, [Var])] -> [[Var]]
redG l = case l of
[] -> []
(AcyclicSCC (g,h,l)):r -> [[g]]
(CyclicSCC l):r -> (enumG l):(redG r)
where enumG :: [(Var,Var,[Var])] -> [Var]
enumG [] = []
enumG ((g,h,l):r) = g:(enumG r)
lf = redG g'
s' = envConcatBinds s vfs
vdefg' = plVdefg s' vdefg
e' = plExp s' e
Case e vbind alts -> Case e' vbind alts'
where e' = plExp s' e
alts' = map (plAlt s') alts
s' = envAddBind s (Vb vbind)
Coerce ty e -> Coerce ty e' where e' = plExp s e
Note str e -> Note str e' where e' = plExp s e
External str ty -> External str ty
plAlt :: Env -> Alt -> Alt
plAlt s alt = case alt of
(Acon qdc tbs vbs e) -> Acon qdc tbs vbs e'
where e' = plExp (foldl (\s -> \vb -> envAddBind s (Vb vb)) s vbs) e
(Alit l e) -> Alit l e' where e' = plExp s e
(Adefault e) -> Adefault e' where e' = plExp s e
applySolutionToVdef :: Env -> Vdef -> Vdef
applySolutionToVdef s ((mName,var), ty, Lam b e) = ((mName,var), ty', e')
where addType t1 (Vb (v,t2)) = Tapp (Tapp (Tcon tcArrow) t2) t1
addType t (Tb (v,k)) = Tforall (v,k) t
ty' = foldl addType ty lo
e' = foldl (\ e -> \ b -> Lam b e) (Lam b e) lo
lo = sortBy orderBind (map (envBindOfVar s) lv)
lv = setToList (envLk s var)
applySolutionToVdef s def = def
applySolutionToExp :: Env -> Exp -> Exp
applySolutionToExp s (Var (mName, v)) = e'
where e' = if mName == "" then app else (Var (mName, v))
app = foldr makeApp (Var (mName, v)) lo
makeApp b e =
case b of
Vb (v,t) -> App e (Var (mName, v))
Tb (v,k) -> Appt e (Tvar v)
lo = sortBy orderBind (map (envBindOfVar s) lv)
lv = setToList (envLk s v)
applySolutionToExp s e = e
-- Block floating
bfModule :: Module -> Module
bfModule (Module name tdefs vdefgs) = Module name tdefs vdefgs'
where vdefgs' = map bfTopLevelVdefg vdefgs
bfTopLevelVdefg :: Vdefg -> Vdefg
bfTopLevelVdefg (Rec vdefs) = Rec vdefs'
where vdefs' = unionM (map (\ (v,vl) -> v:vl) (map bfVdef vdefs))
unionM l = loop [] l
where loop l [] = l
loop l (h:t) = loop (l++h) t
bfTopLevelVdefg (Nonrec vdef) = vdefg
where vdefg =
case vl of
[] -> Nonrec v
_ -> Rec (v:vl)
(v,vl) = bfVdef vdef
bfVdefg :: Vdefg -> (Vdefg, [Vdef])
bfVdefg (Rec vdefs) = (vdefs', vbf)
where vdefs' = case vl of
[] -> Rec []
[x] -> Nonrec x
_ -> Rec vl
(vl, vbf)= foldr bfLam ([], l) vdl
(vdl, l) = foldr (\ (v,vl) -> \ (l1,l2) -> (v:l1,vl++l2)) ([],[]) r
r = map bfVdef vdefs
bfVdefg (Nonrec vdef) = (vdefg, vbf)
where vdefg = case vl of
[] -> Rec []
[x] -> Nonrec x
_ -> error "bfVdefg"
(vl,vbf) = bfLam vdef' ([], l)
(vdef', l) = bfVdef vdef
bfLam :: Vdef -> ([Vdef],[Vdef]) -> ([Vdef], [Vdef])
bfLam lam (ld, lbf) =
case lam of
((m,v),ty,Lam b e) ->
case v of
'p':'r':'o':'p':'z':'u':_ -> (lam:ld, lbf)
'l':'e':'m':'m':'a':'z':'u':_ -> (lam:ld, lbf)
_ -> (ld, ((m,v),ty,Lam b e):lbf)
e -> (e:ld, lbf)
bfVdef :: Vdef -> (Vdef, [Vdef])
bfVdef (qv,ty,e) = ((qv,ty,e'),l) where (e',l) = bfExp e
bfExp :: Exp -> (Exp, [Vdef])
bfExp e = case e of
Var qv -> (Var qv, [])
Dcon qd -> (Dcon qd, [])
Lit l -> (Lit l, [])
App e1 e2 -> (App e1' e2', l1++l2)
where (e1', l1) = bfExp e1
(e2', l2) = bfExp e2
Appt e t -> (Appt e' t, l) where (e', l) = bfExp e
Lam b e -> (Lam b e', l) where (e', l) = bfExp e
Let vdefg e -> (new_e, l1++l2)
where new_e = case vdefg' of
Rec [] -> e'
_ -> Let vdefg' e'
(vdefg', l1) = bfVdefg vdefg
(e', l2) = bfExp e
Case e bind alts -> (Case e' bind alts', l1++l2)
where (e', l1) = bfExp e
(alts', l2) = foldr bfAlt ([],[]) alts
Coerce ty e -> (Coerce ty e', l) where (e', l) = bfExp e
Note s e -> (Note s e', l) where (e', l) = bfExp e
External s ty -> (External s ty, [])
bfAlt :: Alt -> ([Alt],[Vdef]) -> ([Alt],[Vdef])
bfAlt a (la, lv) = case a of
Acon qdcon tbs vbs e -> ((Acon qdcon tbs vbs e'):la, l++lv)
where (e', l) = bfExp e
Alit lit e -> ((Alit lit e'):la, l++lv) where (e', l) = bfExp e
Adefault e -> ((Adefault e'):la, l++lv) where (e', l) = bfExp e
-- Some intermediate functions
-- this ones take a let definition and returns ([var,ty], [(fun, [arg])])
-- var is the list of non-lambda names defined by the let
-- fun is the list of lambda names defined, arg being the bindings made.
splitVars :: Vdefg -> ([(Var,Ty)], [(Var, [Var])])
splitVars vdefg = case vdefg of
Rec vdefs -> foldl splitVdef ([], []) vdefs
Nonrec vdef -> splitVdef ([], []) vdef
splitVdef :: ([(Var,Ty)], [(Var, [Var])]) -> Vdef ->
([(Var,Ty)], [(Var, [Var])])
splitVdef (lv,lf) vdef = case vdef of
((_,v), ty, Lam bind e) -> (lv, (fbind:lf))
where fbind = splitExp (v,[]) (Lam bind e)
((_,v), ty, _) -> (((v,ty):lv), lf)
splitExp :: (Var, [Var]) -> Exp -> (Var, [Var])
splitExp (f,larg) e = case e of
Lam bind e -> splitExp (f,(varOfBind bind):larg) e
_ -> (f,larg)
-- Free variables
fvMapVdefg :: [(Var, [Var])] -> Vdefg -> [(Var, Set Var)]
fvMapVdefg fv vdefg = case vdefg of
Rec vdefs -> foldl (findFv fv) [] vdefs
Nonrec vdef -> findFv fv [] vdef
findFv :: [(Var, [Var])] -> [(Var, Set Var)] -> Vdef -> [(Var, Set Var)]
findFv fv l ((mName, v), ty, e) = l'
where l' = case lookup v fv of
Nothing -> l
Just _ -> fvv:l
fvv = (v, fvExp (fvEmptyEnv mName v) e)
data FvEnv = FE { bound :: Set Var,
free :: Set Var,
modName :: String }
fvEmptyEnv :: String -> Var -> FvEnv
fvEmptyEnv mName fName = FE { bound = mkSet [fName],
free = emptySet,
modName = mName }
returnFree :: FvEnv -> Set Var
returnFree fe = free fe
fvBind :: FvEnv -> Bind -> FvEnv
fvBind fe (Vb (v,t)) = fe { bound = addToSet (bound fe) v }
fvBind fe (Tb (v,k)) = fe { bound = addToSet (bound fe) v }
fvCheckBind :: FvEnv -> Qual Var -> Set Var
fvCheckBind fe (mName, vName) = returnFree fe'
where fe' = if isFree then fe { free = addToSet (free fe) vName }
else fe
isFree = isLocal && (not (elementOf vName (bound fe)))
isLocal = mName == (modName fe)
fvBindVdef :: FvEnv -> Vdef -> FvEnv
fvBindVdef fe ((mName, vName), ty, e) = fvBind fe (Vb (vName, ty))
fvVdefg :: FvEnv -> Vdefg -> FvEnv
fvVdefg fe (Rec vdefs) = foldl fvVdef (foldl fvBindVdef fe vdefs) vdefs
fvVdefg fe (Nonrec vdef) = fvVdef (fvBindVdef fe vdef) vdef
fvVdef :: FvEnv -> Vdef -> FvEnv
fvVdef fe ((mName, vName), ty, e) = fe { free = fvExp fe e }
fvExp :: FvEnv -> Exp -> Set Var
fvExp fe e = case e of
Var qvar -> fvCheckBind fe qvar
Dcon dcon -> returnFree fe
Lit l -> returnFree fe
App e1 e2 -> fv
where fv = union fv1 fv2
fv1 = fvExp fe e1
fv2 = fvExp fe e2
Appt e t -> fv
where fv = union fv1 fv2
fv1 = fvExp fe e
fv2 = fvTy fe t
Lam bind e -> fvExp (fvBind fe bind) e
Let vdefg e -> fv
where fv = fvExp fel e
fel = fvVdefg fe vdefg
Case e (var,ty) alts -> unionManySets (fve:fty:fel)
where fel = map (fvAlt (fvBind fe (Vb (var, ty)))) alts
fty = fvTy fe ty
fve = fvExp fe e
Coerce ty e -> fvExp fe e
Note str e -> fvExp fe e
External str ty -> returnFree fe
fvAlt :: FvEnv -> Alt -> Set Var
fvAlt fe alt = case alt of
Acon qdcon tbinds vbinds e -> fv
where fv = fvExp fb e
fb = foldl (\fe -> \b -> fvBind fe (Vb b)) ft vbinds
ft = foldl (\fe -> \b -> fvBind fe (Tb b)) fe tbinds
Alit l e -> fvExp fe e
Adefault e -> fvExp fe e
fvTy :: FvEnv -> Ty -> Set Var
fvTy fe ty = case ty of
Tvar var -> fvCheckBind fe ("", var)
Tcon qtcon -> fvCheckBind fe qtcon
Tapp ty1 ty2 -> fv
where fv = union fv1 fv2
fv1 = fvTy fe ty1
fv2 = fvTy fe ty2
Tforall tbind ty -> fvTy (fvBind fe (Tb tbind)) ty
-- Naming anonymous functions
nameModule :: Module -> Module
nameModule (Module name tdefs vdefgs) = (Module name tdefs vdefgs')
where vdefgs' = map nTopLevelVdefg vdefgs
nTopLevelVdefg :: Vdefg -> Vdefg
nTopLevelVdefg (Rec vs) = Rec (map nTopLevelVdef vs)
nTopLevelVdefg (Nonrec v) = Nonrec (nTopLevelVdef v)
nTopLevelVdef :: Vdef -> Vdef
nTopLevelVdef ((m,v),t,e) = case v of
'p':'r':'o':'p':'z':'u':_ -> ((m,v),t,e)
'l':'e':'m':'m':'a':'z':'u':_ -> ((m,v),t,e)
_ -> ((m,v), t, nTopLevelExp v e)
-- where a = evalState (nTopLevelExp e) (v, 0)
nTopLevelExp :: String -> Exp -> Exp
nTopLevelExp v e = case e of
Lam b e -> Lam b (nTopLevelExp v e)
_ -> evalState (nExp e) (v,0)
nVdefg :: Vdefg -> State (String, Int) Vdefg
nVdefg (Rec vs) = mapM nVdef vs >>= \a -> return (Rec a)
nVdefg (Nonrec v) = nVdef v >>= \a -> return (Nonrec a)
nVdef :: Vdef -> State (String, Int) Vdef
nVdef ((m,v),t,e) = do
a <- nExp e
return ((m,v),t,a)
nExp :: Exp -> State (String, Int) Exp
nExp e = case e of
Var q -> return (Var q)
Dcon q -> return (Dcon q)
Lit l -> return (Lit l)
App (Lam vb e1) e2 ->
nExp e1 >>= \a -> nExp e2 >>= \b ->
mapState
(\ ((a,b),(f,i)) ->
(Let (Nonrec (("", f ++ "_" ++ show i),
(Tapp(Tapp(Tcon tcArrow) (typeOfBind vb))
(fst (typeExp [] a))),
(Lam vb a)))
(App (Var ("", f ++ "_" ++ show i)) b), (f,i+1)))
(return (a,b))
App (Var ("Property", "forAll")) e' -> return e
App e1 e2 -> nExp e1 >>= (\a -> nExp e2 >>= (\b -> return (App a b)))
Appt e t -> nExp e >>= (\a -> return (Appt a t))
-- Lam b e -> nExp e >>= \a -> return (Lam b a)
Lam b e ->
nExp e >>= \a ->
mapState
(\ (a,(f,i)) ->
(Let (Nonrec (("", f ++ "_" ++ show i),
(Tapp(Tapp(Tcon tcArrow) (typeOfBind b))
(fst (typeExp [] a))),
(Lam b a)))
(Var ("", f ++ "_" ++ show i)), (f,i+1)))
(return a)
Let v e -> nVdefg v >>= (\a -> nExp e >>= (\b -> return (Let a b)))
Case e b as -> do a <- nExp e
c <- mapM nAlt as
return (Case a b c)
Coerce t e -> nExp e >>= \a -> return (Coerce t a)
Note s e -> nExp e >>= \a -> return (Note s a)
External s t -> return (External s t)
nAlt :: Alt -> State (String, Int) Alt
nAlt a = case a of
Acon q ts vs e -> nExp e >>= \a -> return (Acon q ts vs a)
Alit l e -> nExp e >>= \a -> return (Alit l a)
Adefault e -> nExp e >>= \a -> return (Adefault a)
typeOfBind (Vb (v,t)) = t
typeOfBind (Tb (v,k)) = error "typeOfBind"
--typeOfBind (Tb (v,k)) = tWild
-- Putting types back
instance Eq Ty where
(==) (Tvar v1) (Tvar v2) = (v1 == v2)
(==) (Tcon c1) (Tcon c2) = (c1 == c2)
(==) _ _ = False
-- (==) (Tapp t1 t2) (Tapp t1' t2') = (t1 == t1') && (t2 == t2')
--tWild = error "LambdaLifting.tWild: unknown type"
--tWild = observe "hej" $ Tvar "?"
tWild = Tvar "?"
tLk env (mName, var) =
case mName of
"" -> fromMaybe tWild (lookup var env)
"DataziTuple" ->
case var of
"Z2T" -> Tcon ("", "Pair")
_ -> tWild
supType (Tvar "?") t = t
supType t (Tvar "?") = t
supType (Tapp t1 t2) (Tapp t1' t2') = Tapp (supType t1 t1') (supType t2 t2')
supType t1 t2 | t1 == t2 = t1
| otherwise = t1
-- **PJ: why is supType t1 t2 always t1, but still written as two cases?
typeModule (Module name tdefs vdefgs) = (Module name tdefs vdefgs')
where vdefgs' = map (\v -> snd (typeVdefg [] v)) vdefgs
typeVdefg env vdefg = case vdefg of
Rec vs -> (env', Rec vs')
where (env', vs') = foldr loopVdef (env, []) vs
loopVdef v (e,vs) = (e',v':vs) where (e',v') = typeVdef e v
Nonrec v -> (env', Nonrec v') where (env', v') = typeVdef env v
typeVdef env ((mName, v), t, e) = (env', ((mName, v), nt, e'))
where env' = (v,nt):env
nt = supType t t'
(t',e') = typeExp ((v,t):env) e
typeExp env e = case e of
Var qv -> (tLk env qv, e)
Dcon qd -> (tLk env qd, e)
Lit l -> (typeLit l, e)
App e1 e2 -> (tApp, App e1' e2')
where (t1, e1') = typeExp env e1
(t2, e2') = typeExp env e2
tApp = case t1 of
Tapp(Tapp(Tcon arr) targ) tres | (arr == tcArrow) -> tres
Tforall tb ty -> ty
_ -> t1
Appt e1 t -> (tApp, Appt e1' t)
where (te, e1') = typeExp env e1
tApp = Tapp te t
Lam (Vb (v,t)) e1 -> (tLam, Lam (Vb (v,t')) e')
where t' = supType t (fromMaybe tWild (lookup v env))
(te,e') = typeExp ((v,t'):env) e1
tLam = Tapp(Tapp(Tcon tcArrow) t') te
Lam (Tb tb) e1 -> (tLam, Lam (Tb tb) e')
where (te,e') = typeExp env e1
tLam = Tforall tb te
Let vdefg e1 -> (te, Let vdefg' e')
where (te, e') = typeExp env' e1
(env', vdefg') = typeVdefg env vdefg
Case e1 (v,t) alts -> (tc, Case e' (v,t'') alts')
where (tc, alts') = typeAlts ((v,t'):env) alts
t'' = supType (supType (fromMaybe tWild (lookup v env)) t) t'
(t', e') = typeExp env e1
Coerce t e1 -> (t, Coerce t e') where (t', e') = typeExp env e1
Note s e1 -> (t, Note s e') where (t, e') = typeExp env e1
External s t -> (t, External s t)
typeAlts env as = (t, as')
where (t, as') = foldr (\ (t2,a) (t1,as) -> (supType t1 t2, a:as))
(tWild, []) tas
tas = map (tAlt env) as
tAlt env alt = case alt of
Acon qd tb vb e -> (t, Acon qd tb vb e')
where (t,e') = typeExp (vb ++ env) e
Alit l e -> (t, Alit l e') where (t,e') = typeExp env e
Adefault e -> (t, Adefault e') where (t,e') = typeExp env e
typeLit lit = case lit of
Lint _ t -> t
Lrational _ t -> t
Lchar _ t -> t
Lstring _ t -> t