ivo/src/Ivo/Types.hs

174 lines
5.5 KiB
Haskell

module Ivo.Types
( module Ivo.Types.Base
, infer, check
) where
import Ivo.Syntax.Printer
import Ivo.Types.Base
import Control.Applicative ((<|>))
import Control.Monad (when)
import Control.Monad.Except (MonadError, throwError)
import Control.Monad.Reader (MonadReader, runReader, asks, local)
import Control.Monad.RWS
( RWST, evalRWST
, MonadState, state
, MonadWriter, tell, listen
)
import Data.Foldable (forM_, toList)
import Data.HashSet qualified as HS
import Data.HashMap.Strict qualified as HM
import Data.Stream (Stream (..), fromList)
import Data.Text qualified as T
fresh :: MonadState (Stream Text) m => m Type
fresh = state \(Cons n ns) -> (TVar n, ns)
inst :: MonadState (Stream Text) m => Scheme -> m Type
inst (TForall ns t) = foldr (\n_n t' -> substitute1 n_n <$> fresh <*> t') (pure t) ns
lookupVar :: (MonadReader Context m, MonadState (Stream Text) m, MonadError Text m) => Text -> m Type
lookupVar n = do
t_polyq <- asks (HM.!? n)
case t_polyq of
Nothing -> throwError $ "Variable not bound: " <> n
Just t_poly -> inst t_poly
generalize :: MonadReader Context m => Type -> m Scheme
generalize t = do
ctx <- asks HM.keysSet
pure $ TForall (toList $ HS.difference (free t) ctx) t
bindVar :: MonadReader Context m => Text -> Type -> m a -> m a
bindVar n t = local (HM.insert n (TForall [] t))
unify :: MonadWriter [Constraint] m => Type -> Type -> m ()
unify t1 t2 = tell [(t1, t2)]
ctrTy :: MonadState (Stream Text) m => Ctr -> m (Type, [Type])
ctrTy = \case
CUnit -> pure (TUnit, [])
CZero -> pure (TNat, [])
CSucc -> pure (TNat, [TNat])
CChar -> pure (TChar, [TNat])
CNil -> mkUnary TList $ const []
CCons -> mkUnary TList \t_a -> [t_a, TApp TList t_a]
CPair -> mkBinary TProd \t_a t_b -> [t_a, t_b]
CLeft -> mkBinary TSum \t_a _ -> [t_a]
CRight -> mkBinary TSum \_ t_b -> [t_b]
where
mkBinary tc tcas = do
t_a <- fresh
t_b <- fresh
pure (tapp [tc, t_a, t_b], tcas t_a t_b)
mkUnary tc tcas = do
t_a <- fresh
pure (TApp tc t_a, tcas t_a)
j :: (MonadError Text m, MonadReader Context m, MonadState (Stream Text) m, MonadWriter [Constraint] m)
=> CheckExpr -> m Type
j (Var name) = lookupVar name
j (App e_fun e_arg) = do
t_ret <- fresh
t_fun <- j e_fun
t_arg <- j e_arg
unify t_fun (tapp [TAbs, t_arg, t_ret])
pure t_ret
j (Abs n_arg e_ret) = do
t_arg <- fresh
t_ret <- bindVar n_arg t_arg $ j e_ret
pure $ tapp [TAbs, t_arg, t_ret]
j (Let (n_x, e_x) e_ret) = do
(t_x_mono, c) <- listen $ j e_x
s <- solve' c
t_x_poly <- generalize $ substitute s t_x_mono
local (HM.insert n_x t_x_poly) $ j e_ret
-- In a case expression:
-- * the pattern for each branch has the same type as the expression being matched, and
-- * the return type for each branch has the same type as the return type of the case expression as a whole.
j (Case ctrs) = do
t_ret <- fresh
t_x <- fresh
forM_ ctrs \(Pat ctr ns_n e) -> do
(t_x', ts_n) <- ctrTy ctr
unify t_x t_x'
when (length ts_n /= length ns_n) $ throwError "Constructor arity mismatch"
t_ret' <- local (HM.union $ HM.fromList $ zip ns_n $ map (TForall []) ts_n) $ j e
unify t_ret t_ret'
pure $ tapp [TAbs, t_x, t_ret]
j (CtrC ctr) = do
(t_ret, ts_n) <- ctrTy ctr
pure $ foldr (\t_a t_r -> tapp [TAbs, t_a, t_r]) t_ret ts_n
j (Ann () e t_ann) = do
t_ret <- j e
unify t_ret t_ann
pure t_ann
j CallCCC = do
t_a <- fresh
t_b <- fresh
pure $ tapp [TAbs, tapp [TAbs, tapp [TAbs, t_a, t_b], t_a], t_a]
j FixC = do
t_a <- fresh
pure $ tapp [TAbs, tapp [TAbs, t_a, t_a], t_a]
j HoleC = asks show >>= throwError . (<>) "Encountered hole with context: " . T.pack
occurs :: Text -> Type -> Bool
occurs n t = HS.member n (free t)
findDifference :: MonadError (Type, Type) m => Type -> Type -> m (Maybe (Text, Type))
findDifference t1 t2
| t1 == t2 = pure Nothing
| TVar n1 <- t1, not (occurs n1 t2) = pure $ Just (n1, t2)
| TVar _ <- t2 = findDifference t2 t1
| TApp a1 b1 <- t1, TApp a2 b2 <- t2 = (<|>) <$> findDifference a1 a2 <*> findDifference b1 b2
| otherwise = throwError (t1, t2)
unifies :: MonadError (Type, Type) m => Type -> Type -> m Substitution
unifies t1 t2 = do
dq <- findDifference t1 t2
case dq of
Nothing -> pure HM.empty
Just s -> do
ss <- unifies (uncurry substitute1 s t1) (uncurry substitute1 s t2)
pure $ uncurry HM.insert (fmap (substitute ss) s) ss
solve :: MonadError (Type, Type) m => [Constraint] -> m Substitution
solve [] = pure HM.empty
solve (c:cs) = do
s <- uncurry unifies c
ss <- solve (substituteMono s cs)
pure $ HM.union ss (substituteMono ss s)
solve' :: MonadError Text m => [Constraint] -> m Substitution
solve' c = case solve c of
Right ss -> pure ss
Left (t1, t2) -> throwError $ "Could not unify " <> unparseType t1 <> " with " <> unparseType t2
type Inferencer a = RWST Context [Constraint] (Stream Text) (Either Text) a
runInferencer :: Inferencer a -> Either Text (a, [Constraint])
runInferencer m = evalRWST m HM.empty freshNames
where
freshNames = fromList $ do
n <- [0 :: Int ..]
c <- ['a'..'z']
pure $ T.pack if n == 0 then [c] else c : show n
infer :: CheckExpr -> Either Text Scheme
infer e = do
(t, c) <- runInferencer $ j e
s <- solve' c
let t' = substitute s t
pure $ runReader (generalize t') HM.empty
check :: Type -> CheckExpr -> Either Text Scheme
check t_ann e = do
(t, c) <- runInferencer do
t_ret <- j e
unify t_ret t_ann
pure t_ann
s <- solve' c
let t' = substitute s t
pure $ runReader (generalize t') HM.empty