module Base.Typing (Typeable (..)) where
import Control.Monad
import qualified Control.Monad.State as S (State, evalState, gets, modify)
import Curry.Base.Ident
import Curry.Syntax
import Base.Messages (internalError)
import Base.TopEnv
import Base.Types
import Base.TypeSubst
import Base.Utils (foldr2)
import Env.Value ( ValueEnv, ValueInfo (..), lookupValue, qualLookupValue)
data TcState = TcState
{ valueEnv :: ValueEnv
, typeSubst :: TypeSubst
, nextId :: Int
}
type TCM = S.State TcState
getValueEnv :: TCM ValueEnv
getValueEnv = S.gets valueEnv
getTypeSubst :: TCM TypeSubst
getTypeSubst = S.gets typeSubst
modifyTypeSubst :: (TypeSubst -> TypeSubst) -> TCM ()
modifyTypeSubst f = S.modify $ \s -> s { typeSubst = f $ typeSubst s }
getNextId :: TCM Int
getNextId = do
nid <- S.gets nextId
S.modify $ \ s -> s { nextId = succ nid }
return nid
run :: TCM a -> ValueEnv -> a
run m tyEnv = S.evalState m initState
where initState = TcState tyEnv idSubst 0
class Typeable a where
typeOf :: ValueEnv -> a -> Type
instance Typeable Type where
typeOf _ ty = ty
instance Typeable Ident where
typeOf = computeType identType
instance Typeable Pattern where
typeOf = computeType argType
instance Typeable Expression where
typeOf = computeType exprType
instance Typeable Rhs where
typeOf = computeType rhsType
computeType :: (a -> TCM Type) -> ValueEnv -> a -> Type
computeType f tyEnv x = normalize (run doComputeType tyEnv)
where
doComputeType = do
ty <- f x
theta <- getTypeSubst
return (fixTypeVars tyEnv (subst theta ty))
fixTypeVars :: ValueEnv -> Type -> Type
fixTypeVars tyEnv ty = subst (foldr2 bindSubst idSubst tvs tvs') ty
where tvs = filter (>= 0) (typeVars ty)
tvs' = map TypeVariable [n 1,n 2 ..]
n = minimum (0 : concatMap typeVars tys)
tys = [ty1 | (_,Value _ _ (ForAll _ ty1)) <- localBindings tyEnv]
identType :: Ident -> TCM Type
identType x = do
tyEnv <- getValueEnv
instUniv (varType x tyEnv)
litType :: Literal -> TCM Type
litType (Char _ _) = return charType
litType (Int v _) = identType v
litType (Float _ _) = return floatType
litType (String _ _) = return stringType
argType :: Pattern -> TCM Type
argType (LiteralPattern l) = litType l
argType (NegativePattern _ l) = litType l
argType (VariablePattern v) = identType v
argType (ConstructorPattern c ts) = do
tyEnv <- getValueEnv
ty <- instUnivExist (constrType c tyEnv)
tys <- mapM argType ts
unifyList (init (flatten ty)) tys
return (last (flatten ty))
where flatten (TypeArrow ty1 ty2) = ty1 : flatten ty2
flatten ty = [ty]
argType (InfixPattern t1 op t2) =
argType (ConstructorPattern op [t1,t2])
argType (ParenPattern t) = argType t
argType (RecordPattern c fs) = do
tyEnv <- getValueEnv
ty <- liftM arrowBase $ instUnivExist $ constrType c tyEnv
mapM_ (fieldType argType ty) fs
return ty
argType (TuplePattern _ ts)
| null ts = return unitType
| otherwise = liftM tupleType $ mapM argType ts
argType (ListPattern _ ts) = freshTypeVar >>= flip elemType ts
where elemType ty [] = return (listType ty)
elemType ty (t:ts1) =
argType t >>= unify ty >> elemType ty ts1
argType (AsPattern v _) = argType (VariablePattern v)
argType (LazyPattern _ t) = argType t
argType (FunctionPattern f ts) = do
tyEnv <- getValueEnv
ty <- instUniv (funType f tyEnv)
tys <- mapM argType ts
unifyList (init (flatten ty)) tys
return (last (flatten ty))
where flatten (TypeArrow ty1 ty2) = ty1 : flatten ty2
flatten ty = [ty]
argType (InfixFuncPattern t1 op t2) = argType (FunctionPattern op [t1,t2])
exprType :: Expression -> TCM Type
exprType (Literal l) = litType l
exprType (Variable v) = do
tyEnv <- getValueEnv
instUniv (funType v tyEnv)
exprType (Constructor c) = do
tyEnv <- getValueEnv
instUnivExist (constrType c tyEnv)
exprType (Typed e _) = exprType e
exprType (Paren e) = exprType e
exprType (Record c fs) = do
tyEnv <- getValueEnv
ty <- liftM arrowBase $ instUnivExist $ constrType c tyEnv
mapM_ (fieldType exprType ty) fs
return ty
exprType (RecordUpdate e fs) = do
ty <- exprType e
mapM_ (fieldType exprType ty) fs
return ty
exprType (Tuple _ es)
| null es = return unitType
| otherwise = liftM tupleType $ mapM exprType es
exprType (List _ es) = freshTypeVar >>= flip elemType es
where elemType ty [] = return (listType ty)
elemType ty (e:es1) = exprType e >>= unify ty >> elemType ty es1
exprType (ListCompr _ e _) = liftM listType $ exprType e
exprType (EnumFrom _) = return (listType intType)
exprType (EnumFromThen _ _) = return (listType intType)
exprType (EnumFromTo _ _) = return (listType intType)
exprType (EnumFromThenTo _ _ _) = return (listType intType)
exprType (UnaryMinus _ e) = exprType e
exprType (Apply e1 e2) = do
(ty1,ty2) <- exprType e1 >>= unifyArrow
exprType e2 >>= unify ty1
return ty2
exprType (InfixApply e1 op e2) = do
(ty1,ty2,ty3) <- exprType (infixOp op) >>= unifyArrow2
exprType e1 >>= unify ty1
exprType e2 >>= unify ty2
return ty3
exprType (LeftSection e op) = do
(ty1,ty2,ty3) <- exprType (infixOp op) >>= unifyArrow2
exprType e >>= unify ty1
return (TypeArrow ty2 ty3)
exprType (RightSection op e) = do
(ty1,ty2,ty3) <- exprType (infixOp op) >>= unifyArrow2
exprType e >>= unify ty2
return (TypeArrow ty1 ty3)
exprType (Lambda _ args e) = do
tys <- mapM argType args
ty <- exprType e
return (foldr TypeArrow ty tys)
exprType (Let _ e) = exprType e
exprType (Do _ e) = exprType e
exprType (IfThenElse _ e1 e2 e3) = do
exprType e1 >>= unify boolType
ty2 <- exprType e2
ty3 <- exprType e3
unify ty2 ty3
return ty3
exprType (Case _ _ _ alts) = freshTypeVar >>= flip altType alts
where altType ty [] = return ty
altType ty (Alt _ _ rhs:alts1) =
rhsType rhs >>= unify ty >> altType ty alts1
rhsType :: Rhs -> TCM Type
rhsType (SimpleRhs _ e _) = exprType e
rhsType (GuardedRhs es _) = freshTypeVar >>= flip condExprType es
where condExprType ty [] = return ty
condExprType ty (CondExpr _ _ e:es1) =
exprType e >>= unify ty >> condExprType ty es1
fieldType :: (a -> TCM Type) -> Type -> Field a -> TCM Type
fieldType tcheck ty (Field _ l x) = do
tyEnv <- getValueEnv
TypeArrow ty1 ty2 <- instUniv (labelType l tyEnv)
unify ty ty1
lty <- tcheck x
unify ty2 lty
return lty
freshTypeVar :: TCM Type
freshTypeVar = TypeVariable `liftM` getNextId
instType :: Int -> Type -> TCM Type
instType n ty = do
tys <- replicateM n freshTypeVar
return (expandAliasType tys ty)
instUniv :: TypeScheme -> TCM Type
instUniv (ForAll n ty) = instType n ty
instUnivExist :: ExistTypeScheme -> TCM Type
instUnivExist (ForAllExist n n' ty) = instType (n + n') ty
unify :: Type -> Type -> TCM ()
unify ty1 ty2 = do
theta <- getTypeSubst
let ty1' = subst theta ty1
ty2' = subst theta ty2
modifyTypeSubst $ unifyTypes ty1' ty2'
unifyList :: [Type] -> [Type] -> TCM ()
unifyList tys1 tys2 = zipWithM_ unify tys1 tys2
unifyArrow :: Type -> TCM (Type,Type)
unifyArrow ty = do
theta <- getTypeSubst
case subst theta ty of
TypeVariable tv
| tv >= 0 -> do
ty1 <- freshTypeVar
ty2 <- freshTypeVar
modifyTypeSubst (bindVar tv (TypeArrow ty1 ty2))
return (ty1,ty2)
TypeArrow ty1 ty2 -> return (ty1,ty2)
ty' -> internalError ("Base.Typing.unifyArrow (" ++ show ty' ++ ")")
unifyArrow2 :: Type -> TCM (Type,Type,Type)
unifyArrow2 ty = do
(ty1,ty2) <- unifyArrow ty
(ty21,ty22) <- unifyArrow ty2
return (ty1,ty21,ty22)
unifyTypes :: Type -> Type -> TypeSubst -> TypeSubst
unifyTypes (TypeVariable tv1) (TypeVariable tv2) theta
| tv1 == tv2 = theta
unifyTypes (TypeVariable tv) ty theta
| tv >= 0 = bindVar tv ty theta
unifyTypes ty (TypeVariable tv) theta
| tv >= 0 = bindVar tv ty theta
unifyTypes (TypeConstructor tc1 tys1) (TypeConstructor tc2 tys2) theta
| tc1 == tc2 = foldr2 unifyTypes theta tys1 tys2
unifyTypes (TypeConstrained _ tv1) (TypeConstrained _ tv2) theta
| tv1 == tv2 = theta
unifyTypes (TypeArrow ty11 ty12) (TypeArrow ty21 ty22) theta =
unifyTypes ty11 ty21 (unifyTypes ty12 ty22 theta)
unifyTypes (TypeSkolem k1) (TypeSkolem k2) theta
| k1 == k2 = theta
unifyTypes ty1 ty2 _ = internalError $
"Base.Typing.unify: (" ++ show ty1 ++ ") (" ++ show ty2 ++ ")"
constrType :: QualIdent -> ValueEnv -> ExistTypeScheme
constrType c tyEnv = case qualLookupValue c tyEnv of
[DataConstructor _ _ _ sigma] -> sigma
[NewtypeConstructor _ _ sigma] -> sigma
_ -> internalError $ "Base.Typing.constrType: " ++ show c
varType :: Ident -> ValueEnv -> TypeScheme
varType v tyEnv = case lookupValue v tyEnv of
[Value _ _ sigma] -> sigma
[Label _ _ sigma] -> sigma
_ -> internalError $ "Base.Typing.varType: " ++ show v
funType :: QualIdent -> ValueEnv -> TypeScheme
funType f tyEnv = case qualLookupValue f tyEnv of
[Value _ _ sigma] -> sigma
[Label _ _ sigma] -> sigma
_ -> internalError $ "Base.Typing.funType: " ++ show f
labelType :: QualIdent -> ValueEnv -> TypeScheme
labelType l tyEnv = case qualLookupValue l tyEnv of
[Label _ _ sigma] -> sigma
_ -> internalError $ "Base.Typing.labelType: " ++ show l