module Paradox.AnalysisTypes ( types ) where {- Paradox -- Copyright (c) 2003-2007, Koen Claessen, Niklas Sorensson Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -} import Form import Name import Data.Set( Set ) import qualified Data.Set as S import Data.Map( Map ) import qualified Data.Map as M import Control.Monad.ST ( ST , runST ) import Data.STRef ( STRef , newSTRef , readSTRef , writeSTRef ) import Char ( ord , chr ) import List ( nub ) --------------------------------------------------------------------------------- -- types types :: [Clause] -> Either String ([Type], Clause -> Clause) types cs = runT (inferClauses cs) --------------------------------------------------------------------------------- -- infer inferClauses :: [Clause] -> T s () inferClauses [] = do return () inferClauses (c:cs) = do scope (free c) $ do sequence_ [ inferLit l | l <- c ] inferClauses cs inferLit :: Signed Atom -> T s () inferLit (Pos x) = inferAtom True x inferLit (Neg x) = inferAtom False x inferAtom :: Bool -> Atom -> T s () inferAtom sgn (x :=: y) = do t1 <- inferTerm x t2 <- inferTerm y t1 =:= t2 if sgn then touchEq t1 eq else return () where eq = case (x,y) of (Var _, Var _) -> Full (Var _, _ ) -> Half (_ , Var _) -> Half _ -> Safe inferTerm :: Term -> T s (TypeId s) inferTerm (Var v) = do getVar v inferTerm (Fun f xs) = do (ts,t) <- getFun f (length xs) inferAndUnify f xs ts return t inferAndUnify :: Symbol -> [Term] -> [TypeId s] -> T s () inferAndUnify s xs ts | length xs == length ts = sequence_ [ do t' <- inferTerm x t =:= t' | (x,t) <- xs `zip` ts ] | otherwise = throw ( "Symbol '" ++ show s ++ "' used with different arities " ++ show (length xs) ++ " and " ++ show (length ts) ) --------------------------------------------------------------------------------- -- run runT :: (forall s . T s a) -> Either String ([Type], Clause -> Clause) runT tm = runST (runT' tm) runT' :: T s a -> ST s (Either String ([Type], Clause -> Clause)) runT' (MkT m) = do idfs <- newSTRef 0 preds <- newSTRef M.empty funs <- newSTRef M.empty vars <- newSTRef M.empty m idfs preds funs vars (\s -> return (Left s)) (\_ -> do ps' <- readSTRef preds fs' <- readSTRef funs ps <- sequence [ do ts <- sequence [ do (_,t,_) <- typeInfo t' return t | t' <- ts' ] return (p,ts) | (p,ts') <- M.toList ps' ] fs <- sequence [ do ts <- sequence [ do (_,t,_) <- typeInfo t' return t | t' <- ts' ] (_,t,_) <- typeInfo t' return (f,(ts,t)) | (f,(ts',t')) <- M.toList fs' ] typeIds' <- sequence [ do ((_,eq),t,_) <- typeInfo t' return (t,eq) | t' <- [ t | (_,ts) <- M.toList ps', t <- ts ] ++ [ t | (_,(ts,tr)) <- M.toList fs', t <- tr:ts ] ] let typeIds = S.toList (S.fromList typeIds') names = [ s | i <- [1..] , let s | i <= 26 = name [chr (ord 'A' + i - 1)] | otherwise = name "T" % (i-26) ] typesAndTypeIds = [ ( Type { tname = s , tsize = n , tequal = eq } , t ) | (s,(t,eq)) <- names `zip` typeIds , let fResT = [ (f,length ts) | (f,(ts,t')) <- fs, t == t' ] n = case [ ar | (f,ar) <- fResT, ar > 0 ] of [] -> Just (length fResT `max` 1) _ -> Nothing ] types = [ t | (t,_) <- typesAndTypeIds ] typeIdToType = M.fromList [ (tid,t) | (t,tid) <- typesAndTypeIds ] typeOfId tid = case M.lookup tid typeIdToType of Just t -> t Nothing -> error "Types: no type" predTable = M.fromList [ (n, map typeOfId ts :-> bool) | (n ::: _, ts) <- ps ] funTable = M.fromList [ (n, map typeOfId ts :-> typeOfId t) | (n ::: _, (ts,t)) <- fs ] typeOfPred (p ::: _) = case M.lookup p predTable of Just t -> t Nothing -> error $ "Types: no pred type" typeOfFun (f ::: _) = case M.lookup f funTable of Just t -> t Nothing -> error "Types: no fun type" trans c = map transLit c where ls = c varsLit (Pos a) = varsAtom a varsLit (Neg a) = varsAtom a varsAtom (y :=: Fun f xs) = varsTerms ((y,t):(xs `zip` ts)) where ts :-> t = typeOfFun f varsAtom (Fun f xs :=: y) = varsTerms ((y,t):(xs `zip` ts)) where ts :-> t = typeOfFun f varsAtom (x :=: y) = varsTerms [(x,top),(y,top)] varsTerms [] = [] varsTerms ((Fun f xs,_):xts) = varsTerms ((xs `zip` ts)++xts) where ts :-> _ = typeOfFun f varsTerms ((Var (n ::: _),t):xts) | t == top = varsTerms xts | otherwise = (n,V t) : varsTerms xts varTable = M.fromList (concatMap varsLit ls) transLit (Pos a) = Pos (transAtom a) transLit (Neg a) = Neg (transAtom a) transAtom (x :=: y) = transTerm x :=: transTerm y transTerm (Fun f@(n ::: (_ :-> t)) xs) = Fun f' (map transTerm xs) where ts :-> t' = typeOfFun f f' = n ::: (ts :-> (if t == bool then bool else t')) transTerm (Var v@(n ::: _)) = Var v' where v' = n ::: case M.lookup n varTable of Just t -> t Nothing -> V top return (Right (types, trans)) ) ------------------------------------------------------------------------- -- T monad type TypeInfo = (Int,Equality) data TypeId s = MkTypeId !Int !(STRef s (Either TypeInfo (TypeId s))) instance Eq (TypeId s) where MkTypeId n1 _ == MkTypeId n2 _ = n1 == n2 instance Ord (TypeId s) where MkTypeId n1 _ `compare` MkTypeId n2 _ = n1 `compare` n2 newtype T s a = MkT ( forall b . STRef s Int -- unique name table -> STRef s (Map Symbol [TypeId s]) -- predicate table -> STRef s (Map Symbol ([TypeId s], TypeId s)) -- function table -> STRef s (Map Symbol (TypeId s)) -- variable table -> (String -> ST s b) -> (a -> ST s b) -> ST s b ) -- monad instance Monad (T s) where return x = MkT (\idfs preds funs vars fail ok -> ok x ) MkT m1 >>= k = MkT (\idfs preds funs vars fail ok -> m1 idfs preds funs vars fail (\a -> let MkT m2 = k a in m2 idfs preds funs vars fail ok )) -- primitives getIdfs :: T s (STRef s Int) getIdfs = MkT (\idfs preds funs vars fail ok -> ok idfs) getPreds :: T s (STRef s (Map Symbol [TypeId s])) getPreds = MkT (\idfs preds funs vars fail ok -> ok preds) getFuns :: T s (STRef s (Map Symbol ([TypeId s], TypeId s))) getFuns = MkT (\idfs preds funs vars fail ok -> ok funs) getVars :: T s (STRef s (Map Symbol (TypeId s))) getVars = MkT (\idfs preds funs vars fail ok -> ok vars) throw :: String -> T s a throw s = MkT (\idfs preds funs vars fail ok -> fail s) lift :: ST s a -> T s a lift m = MkT (\_ _ _ _ _ ok -> do a <- m; ok a) -- derived new :: T s (TypeId s) new = do ref <- lift (newSTRef (Left (1,Safe))) idfs <- getIdfs n <- lift $ readSTRef idfs let n' = n + 1 n' `seq` lift (writeSTRef idfs n') return (MkTypeId n ref) typeInfo :: TypeId s -> ST s (TypeInfo, Int, STRef s (Either TypeInfo (TypeId s))) typeInfo (MkTypeId i ref) = do mref <- readSTRef ref case mref of Left inf -> do return (inf,i,ref) Right t -> do (inf,i,c) <- typeInfo t writeSTRef ref (Right (MkTypeId i c)) return (inf,i,c) touchEq :: TypeId s -> Equality -> T s () touchEq t eq' = do ((n,eq),_,ref) <- lift (typeInfo t) lift $ writeSTRef ref (Left (n,eq `max` eq')) getPred :: Symbol -> Int -> T s [TypeId s] getPred p n = do preds <- getPreds ps <- lift (readSTRef preds) case M.lookup p ps of Just ts -> do return ts Nothing -> do ts <- sequence [ new | i <- [1..n] ] lift (writeSTRef preds (M.insert p ts ps)) return ts getFun :: Symbol -> Int -> T s ([TypeId s], TypeId s) getFun f n = do funs <- getFuns fs <- lift (readSTRef funs) case M.lookup f fs of Just (ts,t) -> do return (ts,t) Nothing -> do ts <- sequence [ new | i <- [1..n] ] t <- new lift (writeSTRef funs (M.insert f (ts,t) fs)) return (ts,t) scope :: Set Symbol -> T s a -> T s a scope news tm = do vars <- getVars vs <- lift (readSTRef vars) vts <- sequence [ do t <- new return (v,t) | v <- S.toList news ] lift (writeSTRef vars (vs `M.union` M.fromList vts)) -- order is important a <- tm lift (writeSTRef vars vs) return a getVar :: Symbol -> T s (TypeId s) getVar v = do vars <- getVars vs <- lift (readSTRef vars) case M.lookup v vs of Just t -> do return t Nothing -> do error "var not in scope" (=:=) :: TypeId s -> TypeId s -> T s () MkTypeId i1 ref1 =:= MkTypeId i2 ref2 = lift ( do unify i1 ref1 i2 ref2 return () ) where unify i1 ref1 i2 ref2 | ref1 == ref2 = return (i1,ref1) | otherwise = do mref1 <- readSTRef ref1 case mref1 of Left (n1,eq1) -> do mref2 <- readSTRef ref2 case mref2 of Left (n2,eq2) | n1 < n2 -> do writeSTRef ref1 (Right (MkTypeId i2 ref2)) writeSTRef ref2 (Left (n1+n2,(eq1 `max` eq2))) return (i2,ref2) | otherwise -> do writeSTRef ref2 (Right (MkTypeId i1 ref1)) writeSTRef ref1 (Left (n1+n2,(eq1 `max` eq2))) return (i1,ref1) Right (MkTypeId i2' ref2') -> do (i,ref) <- unify i1 ref1 i2' ref2' writeSTRef ref2 (Right (MkTypeId i ref)) return (i,ref) Right (MkTypeId i1' ref1') -> do (i,ref) <- unify i1' ref1' i2 ref2 writeSTRef ref1 (Right (MkTypeId i ref)) return (i,ref) unifyList :: [TypeId s] -> T s () unifyList [] = return () unifyList (t:ts) = sequence_ [ t' =:= t | t' <- ts ] ------------------------------------------------------------------------- -- the end.