module Transformations.Lift (lift) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>), (<*>))
#endif
import Control.Arrow (first)
import qualified Control.Monad.State as S (State, runState, gets, modify)
import Data.List
import qualified Data.Map as Map (Map, empty, insert, lookup)
import qualified Data.Set as Set (toList, fromList, unions)
import Curry.Base.Ident
import Curry.Base.Position (Position)
import Curry.Syntax
import Base.Expr
import Base.Messages (internalError)
import Base.SCC
import Base.Types
import Env.Value
-- | Lambda-lift a module.
--
-- Every top-level declaration is first abstracted ('absDecl' with an empty
-- name prefix and no enclosing local variables), then the abstracted
-- declarations are flattened with 'liftFunDecl'. The value environment
-- returned is the one left in the state after abstraction.
lift :: ValueEnv -> Module -> (Module, ValueEnv)
lift tyEnv (Module ps m es is ds) =
  let start           = LiftState m tyEnv Map.empty
      (absDs, final)  = S.runState (mapM (absDecl "" []) ds) start
      liftedDs        = concatMap liftFunDecl absDs
  in  (Module ps m es is liftedDs, valueEnv final)
-- | Associates a local function with the qualified name it is lifted to and
-- the local variables it must be applied to at each use site.
type AbstractEnv = Map.Map Ident (QualIdent, [Ident])

-- | State threaded through the abstraction phase of lambda lifting.
-- NOTE: the field order is relied upon by positional construction in 'lift'.
data LiftState = LiftState
  { moduleIdent :: ModuleIdent  -- ^ identifier of the module being lifted
  , valueEnv    :: ValueEnv     -- ^ value environment; lifted functions are rebound here
  , abstractEnv :: AbstractEnv  -- ^ abstracted local functions currently in scope
  }

-- | The abstraction monad: a state monad over 'LiftState'.
type LiftM a = S.State LiftState a
-- | Identifier of the module being transformed.
getModuleIdent :: LiftM ModuleIdent
getModuleIdent = S.gets $ \ LiftState { moduleIdent = mid } -> mid

-- | Current value environment.
getValueEnv :: LiftM ValueEnv
getValueEnv = S.gets $ \ LiftState { valueEnv = env } -> env

-- | Apply a transformation to the value environment held in the state.
modifyValueEnv :: (ValueEnv -> ValueEnv) -> LiftM ()
modifyValueEnv f =
  S.modify $ \ st@LiftState { valueEnv = env } -> st { valueEnv = f env }

-- | Current abstraction environment.
getAbstractEnv :: LiftM AbstractEnv
getAbstractEnv = S.gets $ \ LiftState { abstractEnv = aEnv } -> aEnv
-- | Run an action with the abstraction environment temporarily replaced by
-- @ae@. The previous environment is reinstated once the action finishes,
-- discarding any changes the action made to it. Changes to other state
-- fields made by the action are kept.
withLocalAbstractEnv :: AbstractEnv -> LiftM a -> LiftM a
withLocalAbstractEnv ae act = do
  saved  <- getAbstractEnv
  setAbstractEnv ae
  result <- act
  setAbstractEnv saved
  return result
  where
    setAbstractEnv env = S.modify $ \ st -> st { abstractEnv = env }
-- | Abstract a single declaration.
--
-- Function declarations have their equations abstracted. Each equation
-- derives its own name prefix, so @pre@ is not used for functions. Pattern
-- declarations have their right-hand side abstracted under @pre@. All other
-- declarations are returned unchanged.
absDecl :: String -> [Ident] -> Decl -> LiftM Decl
absDecl pre lvs decl = case decl of
  FunctionDecl p f eqs -> FunctionDecl p f <$> mapM (absEquation lvs) eqs
  PatternDecl  p t rhs -> PatternDecl p t <$> absRhs pre lvs rhs
  _                    -> return decl
-- | Abstract one equation of a function.
--
-- The variables bound by the argument patterns are added to the local
-- variables. The right-hand side uses the prefix @f.@, where @f@ is the
-- function name. Only simple function left-hand sides are supported.
absEquation :: [Ident] -> Equation -> LiftM Equation
absEquation lvs (Equation p lhs@(FunLhs f ts) rhs) = do
  lhs' <- absLhs lhs
  rhs' <- absRhs (idName f ++ ".") (lvs ++ bv ts) rhs
  return (Equation p lhs' rhs')
absEquation _ _ = error "Lift.absEquation: no pattern match"
-- | Abstract the argument patterns of a simple function left-hand side.
absLhs :: Lhs -> LiftM Lhs
absLhs lhs = case lhs of
  FunLhs f ts -> fmap (FunLhs f) (mapM absPat ts)
  _           -> error "Lift.absLhs: no simple LHS"
-- | Abstract the expression of a simple right-hand side.
--
-- The rebuilt right-hand side has an empty local declaration list. Any
-- declarations in the original third field are not carried over.
absRhs :: String -> [Ident] -> Rhs -> LiftM Rhs
absRhs pre lvs rhs = case rhs of
  SimpleRhs p e _ -> fmap (simpleRhs p) (absExpr pre lvs e)
  _               -> error "Lift.absRhs: no simple RHS"
-- | Abstract the declarations of a @let@ together with its body.
--
-- Function (and foreign) declarations are split from the other
-- declarations and grouped with 'scc' by their bound/free identifiers. The
-- variables bound by the non-function declarations become local variables.
absDeclGroup :: String -> [Ident] -> [Decl] -> Expression -> LiftM Expression
absDeclGroup pre lvs ds e = getModuleIdent >>= \ m ->
  let (funDs, varDs) = partition isFunDecl ds
      groups         = scc bv (qfv m) funDs
  in  absFunDecls pre (lvs ++ bv varDs) groups varDs e
-- | Abstract a sequence of function declaration groups, followed by the
-- remaining variable declarations and the @let@ body.
--
-- For each group:
--
-- * collect the identifiers free in the group, expanding references to
--   already-abstracted functions into the variables they are applied to;
-- * keep those that are enclosing local variables (@captured@);
-- * map every function of the group to its lifted name and @captured@;
-- * rebind the types of the group's functions that have a value-environment
--   entry via 'absFunTypes';
-- * abstract those declarations (only ones defining at least such a
--   function are kept) and the remaining groups under the extended
--   abstraction environment, nesting each group in its own 'Let'.
absFunDecls :: String -> [Ident] -> [[Decl]] -> [Decl] -> Expression
            -> LiftM Expression
absFunDecls pre lvs [] vds e =
  Let <$> mapM (absDecl pre lvs) vds <*> absExpr pre lvs e
absFunDecls pre lvs (grp : grps) vds e = do
  m     <- getModuleIdent
  aEnv  <- getAbstractEnv
  tyEnv <- getValueEnv
  let defined    = bv grp
      -- a reference to an abstracted function stands for its whole call
      expand v   = maybe [v] (qfv m . asFunCall) (Map.lookup v aEnv)
      used       = Set.unions (map (Set.fromList . expand) (qfv m grp))
      captured   = filter (`elem` lvs) (Set.toList used)
      register f = Map.insert f (qualifyWith m (liftIdent pre f), captured)
      aEnv'      = foldr register aEnv defined
      typed      = filter (\ f -> not (null (lookupValue f tyEnv))) defined
  modifyValueEnv (absFunTypes m pre captured typed)
  withLocalAbstractEnv aEnv' $ do
    grp'  <- mapM (absFunDecl pre captured lvs)
                  (filter (any (`elem` typed) . bv) grp)
    body' <- absFunDecls pre lvs grps vds e
    return (Let grp' body')
-- | Replace the bindings of the given functions with bindings for their
-- lifted names.
--
-- Each lifted function gets one extra parameter per captured variable in
-- @fvs@. Its arity grows by @length fvs@. Its type is the captured
-- variables' types (taken from the incoming environment) prepended to the
-- original type, normalised with 'normType' and generalised with 'polyType'.
absFunTypes :: ModuleIdent -> String -> [Ident] -> [Ident]
            -> ValueEnv -> ValueEnv
absFunTypes m pre fvs fs tyEnv = foldr rebind tyEnv fs
  where
    capturedTys = map (varType tyEnv) fvs
    rebind f env =
      let arity = length fvs + varArity env f
          fullTy = foldr TypeArrow (varType env f) capturedTys
      in  qualBindFun m (liftIdent pre f) arity
                      (polyType (normType fullTy))
                      (unbindFun f env)
-- | Renumber the type variables of a type as 0, 1, 2, ... in the order of
-- their first occurrence in the result of 'typeVars'.
--
-- Constrained and skolem types are returned unchanged.
normType :: Type -> Type
normType ty = rename ty
  where
    renaming = zip (nub (typeVars ty)) [0 ..]
    rename (TypeVariable n) =
      maybe (error "Lift.normType") TypeVariable (lookup n renaming)
    rename (TypeConstructor tc args) = TypeConstructor tc (map rename args)
    rename (TypeArrow dom cod)       = TypeArrow (rename dom) (rename cod)
    rename t@(TypeConstrained _ _)   = t
    rename t@(TypeSkolem _)          = t
-- | Turn a local function (or foreign) declaration into its lifted form.
--
-- A function is renamed with 'liftIdent'. Each equation gets the captured
-- variables @fvs@ as extra leading variable patterns, and the result is then
-- abstracted with 'absDecl'. A foreign declaration is only renamed.
absFunDecl :: String -> [Ident] -> [Ident] -> Decl -> LiftM Decl
absFunDecl pre fvs lvs decl = case decl of
  FunctionDecl p f eqs ->
    let lifted = liftIdent pre f
    in  absDecl pre lvs (FunctionDecl p lifted (map (prependVars lifted) eqs))
  ForeignDecl p cc ie f ty ->
    return (ForeignDecl p cc ie (liftIdent pre f) ty)
  _ -> error "Lift.absFunDecl: no pattern match"
  where
    prependVars name (Equation p1 (FunLhs _ ts) rhs) =
      Equation p1 (FunLhs name (map VariablePattern fvs ++ ts)) rhs
    prependVars _ _ = error "Lift.absFunDecl.addVars: no pattern match"
-- | Abstract an expression.
--
-- An unqualified variable naming an abstracted local function is replaced
-- by the call of its lifted name applied to its captured variables, and that
-- call is abstracted in turn. The other supported forms are traversed
-- structurally. Anything else is an internal error.
absExpr :: String -> [Ident] -> Expression -> LiftM Expression
absExpr pre lvs expr = case expr of
  Literal _ -> return expr
  Variable v
    | isQualified v -> return expr
    | otherwise     -> do
        aEnv <- getAbstractEnv
        maybe (return expr) (absExpr pre lvs . asFunCall)
              (Map.lookup (unqualify v) aEnv)
  Constructor _    -> return expr
  Apply e1 e2      -> Apply <$> absExpr pre lvs e1 <*> absExpr pre lvs e2
  Let ds e         -> absDeclGroup pre lvs ds e
  Case r ct e alts -> Case r ct <$> absExpr pre lvs e
                                <*> mapM (absAlt pre lvs) alts
  Typed e ty       -> (\ e' -> Typed e' ty) <$> absExpr pre lvs e
  _                -> internalError $ "Lift.absExpr: " ++ show expr
-- | Abstract a case alternative. The variables bound by its pattern become
-- local variables of the right-hand side.
absAlt :: String -> [Ident] -> Alt -> LiftM Alt
absAlt pre lvs (Alt p t rhs) = do
  rhs' <- absRhs pre (lvs ++ bv t) rhs
  return (Alt p t rhs')
-- | Abstract a pattern.
--
-- A function pattern whose (unqualified) name is an abstracted local
-- function is redirected to the lifted name, with the captured variables
-- prepended as variable patterns. Other supported patterns are traversed
-- structurally.
absPat :: Pattern -> LiftM Pattern
absPat pat = case pat of
  VariablePattern _       -> return pat
  LiteralPattern _        -> return pat
  ConstructorPattern c ps -> ConstructorPattern c <$> mapM absPat ps
  AsPattern v p           -> AsPattern v <$> absPat p
  FunctionPattern f ps    -> do
    aEnv <- getAbstractEnv
    ps'  <- mapM absPat ps
    return $ case Map.lookup (unqualify f) aEnv of
      Nothing       -> FunctionPattern f ps'
      Just (f', vs) -> FunctionPattern f' (map VariablePattern vs ++ ps')
  _                       -> error $ "Lift.absPat: " ++ show pat
-- | Flatten a function declaration. Declarations collected from its
-- equations are emitted after it, in equation order. Other declarations
-- are returned as a singleton list.
liftFunDecl :: Decl -> [Decl]
liftFunDecl decl = case decl of
  FunctionDecl p f eqs ->
    let results = map liftEquation eqs
    in  FunctionDecl p f (map fst results) : concatMap snd results
  _ -> [decl]
-- | Flatten a variable declaration, returning the rewritten declaration and
-- the declarations extracted from it. Free declarations are unchanged.
liftVarDecl :: Decl -> (Decl, [Decl])
liftVarDecl decl = case decl of
  PatternDecl p t rhs ->
    let (rhs', extracted) = liftRhs rhs
    in  (PatternDecl p t rhs', extracted)
  FreeDecl _ _ -> (decl, [])
  _            -> error "Lift.liftVarDecl: no pattern match"
-- | Flatten the right-hand side of an equation; the left-hand side is kept.
liftEquation :: Equation -> (Equation, [Decl])
liftEquation (Equation p lhs rhs) =
  let (rhs', extracted) = liftRhs rhs
  in  (Equation p lhs rhs', extracted)
-- | Flatten a simple right-hand side. The result has an empty local
-- declaration list.
liftRhs :: Rhs -> (Rhs, [Decl])
liftRhs rhs = case rhs of
  SimpleRhs p e _ -> first (simpleRhs p) $ liftExpr e
  _               -> error "Lift.liftRhs: no pattern match"
-- | Flatten the declarations of a @let@.
--
-- Returns the remaining (non-function) declarations, and the declarations
-- to be moved out: the flattened function declarations first, then those
-- extracted from the variable declarations.
liftDeclGroup :: [Decl] -> ([Decl],[Decl])
liftDeclGroup ds = (varDs', concatMap liftFunDecl funDs ++ concat extracted)
  where
    (funDs, varDs)      = partition isFunDecl ds
    (varDs', extracted) = unzip (map liftVarDecl varDs)
-- | Flatten an expression, returning it together with the declarations
-- moved out of its @let@s. A @let@ left with no declarations is dropped
-- via 'mkLet'.
liftExpr :: Expression -> (Expression, [Decl])
liftExpr expr = case expr of
  Literal _     -> (expr, [])
  Variable _    -> (expr, [])
  Constructor _ -> (expr, [])
  Apply e1 e2 ->
    let (e1', ds1) = liftExpr e1
        (e2', ds2) = liftExpr e2
    in  (Apply e1' e2', ds1 ++ ds2)
  Let ds e ->
    let (kept, ds1) = liftDeclGroup ds
        (e', ds2)   = liftExpr e
    in  (mkLet kept e', ds1 ++ ds2)
  Case r ct e alts ->
    let (e', scrutDs) = liftExpr e
        (alts', altDss) = unzip (map liftAlt alts)
    in  (Case r ct e' alts', concat (scrutDs : altDss))
  Typed e ty ->
    let (e', ds) = liftExpr e
    in  (Typed e' ty, ds)
  _ -> internalError "Lift.liftExpr"
-- | Flatten the right-hand side of a case alternative.
liftAlt :: Alt -> (Alt, [Decl])
liftAlt (Alt p t rhs) = first (Alt p t) (liftRhs rhs)
-- | True for function and foreign declarations.
isFunDecl :: Decl -> Bool
isFunDecl decl = case decl of
  FunctionDecl {} -> True
  ForeignDecl  {} -> True
  _               -> False
-- | Build the call of a lifted function: its name applied to the variables
-- it captures, in order.
asFunCall :: (QualIdent, [Ident]) -> Expression
asFunCall (f, vs) = apply (Variable f) [mkVar v | v <- vs]
-- | Variable expression for an identifier, wrapped with 'qualify'.
mkVar :: Ident -> Expression
mkVar = Variable . qualify
-- | Wrap an expression in a @let@, unless there are no declarations.
mkLet :: [Decl] -> Expression -> Expression
mkLet [] e = e
mkLet ds e = Let ds e
-- | Apply an expression to a list of arguments, left-associatively:
-- @apply f [a, b]@ is @Apply (Apply f a) b@.
--
-- Uses the strict left fold 'foldl'' (in scope through the unqualified
-- "Data.List" import) instead of lazy 'foldl'. Each step only builds an
-- 'Apply' constructor, so the result is the same, but no chain of fold
-- thunks accumulates.
apply :: Expression -> [Expression] -> Expression
apply = foldl' Apply
-- | A simple right-hand side with no local declarations.
simpleRhs :: Position -> Expression -> Rhs
simpleRhs p = flip (SimpleRhs p) []
-- | Arity recorded for a variable. Exactly one 'Value' entry is expected;
-- anything else is an internal error.
varArity :: ValueEnv -> Ident -> Int
varArity tyEnv v
  | [Value _ arity _] <- lookupValue v tyEnv = arity
  | otherwise = internalError $ "Lift.varArity: " ++ show v
-- | Type (under its 'ForAll' quantifier) recorded for a variable. Exactly
-- one 'Value' entry is expected; anything else is an internal error.
varType :: ValueEnv -> Ident -> Type
varType tyEnv v
  | [Value _ _ (ForAll _ ty)] <- lookupValue v tyEnv = ty
  | otherwise = internalError $ "Lift.varType: " ++ show v
-- | Name of a lifted function: @prefix@ followed by 'showIdent' of the
-- original identifier, keeping the original's unique number.
liftIdent :: String -> Ident -> Ident
liftIdent prefix x =
  let liftedName = prefix ++ showIdent x
  in  renameIdent (mkIdent liftedName) (idUnique x)