From 9e0754daf6d0c090707498261750d2564ea3d342 Mon Sep 17 00:00:00 2001 From: James Martin Date: Thu, 18 Mar 2021 00:00:43 -0700 Subject: [PATCH] Added support for Hindley-Milner type inference! I also generalized substitution to *all* expressions (and types), which itself involved a rewrite of the substitution function I already had. (In hindsight, I wish I'd done that in a separate commit.) The last four days have all been working up to this, so I'm glad I've finally managed to integrate it into this project! (I wrote the algorithm 4 days ago, but the infrastructure just wasn't there to add it here.) --- README.md | 87 ++++--- app/Main.hs | 8 +- package.yaml | 3 +- src/LambdaCalculus.hs | 10 +- src/LambdaCalculus/Evaluator.hs | 150 ++---------- src/LambdaCalculus/Evaluator/Base.hs | 75 ++++-- src/LambdaCalculus/Evaluator/Continuation.hs | 2 +- src/LambdaCalculus/Expression.hs | 86 +++++-- src/LambdaCalculus/Expression/Base.hs | 136 ++++++++++- src/LambdaCalculus/Syntax/Base.hs | 22 +- src/LambdaCalculus/Syntax/Parser.hs | 70 +++--- src/LambdaCalculus/Syntax/Printer.hs | 7 +- src/LambdaCalculus/Types.hs | 158 ++++++++++++ src/LambdaCalculus/Types/Base.hs | 242 +++++++++++++++++++ test/Spec.hs | 4 +- 15 files changed, 806 insertions(+), 254 deletions(-) create mode 100644 src/LambdaCalculus/Types.hs create mode 100644 src/LambdaCalculus/Types/Base.hs diff --git a/README.md b/README.md index fe75879..c9dec64 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,14 @@ # Lambda Calculus -This is a simple programming language derived from lambda calculus. +This is a simple programming language derived from lambda calculus, +using the Hindley-Milner type system, plus some builtin types, `fix`, and `callcc` ## Usage Run the program using `stack run` (or run the tests with `stack test`). -Type in your expression at the prompt: `>> `. -The expression will be evaluated to normal form using the call-by-value evaluation strategy and then printed. +Type in your expression at the prompt: `>> `. This will happen: +* the type for the expression will be inferred and then printed, +* then, regardless of whether typechecking succeeded, expression will be evaluated to normal form using the call-by-value evaluation strategy and then printed. + Exit the prompt with `Ctrl-c` (or equivalent). ## Syntax @@ -16,23 +19,38 @@ The parser's error messages currently are virtually useless, so be very careful * Lambda abstraction: `\x y z. E` or `λx y z. E` * Let expressions: `let x = E; y = F in G` * Parenthetical expressions: `(E)` -* Constructors: `()`, `(x, y)` (or `(,) x y`), `Left x`, `Right y`, `Z`, `S`, `[]`, `(x : xs)` (or `(:) x xs`), `Char n`. +* Constructors: `()`, `(x, y)` (or `(,) x y`), `Left x`, `Right y`, `Z`, `S`, `[]`, `(x :: xs)` (or `(:) x xs`), `Char n`. * The parentheses around the cons constructor are not optional. * `Char` takes a natural number and turns it into a character. -* Pattern matchers: `{ Left x -> e ; Right y -> f }` +* Pattern matchers: `case { Left a -> e ; Right y -> f }` * Pattern matchers can be applied like functions, e.g. `{ Z -> x, S -> y } 10` reduces to `y`. - * Patterns must use the regular form of the constructor, e.g. `(x : xs)` and not `((:) x xs)`. + * Patterns must use the regular form of the constructor, e.g. `(x :: xs)` and not `((::) x xs)`. * There are no nested patterns or default patterns. * Incomplete pattern matches will crash the interpreter. * Literals: `1234`, `[e, f, g, h]`, `'a`, `"abc"` * Strings are represented as lists of characters. +* Type annotations: there are no type annotations; types are inferred only. -## Call/CC -This interpreter has preliminary support for -[the call-with-current-continuation control flow operator](https://en.wikipedia.org/wiki/Call-with-current-continuation). -However, it has not been thoroughly tested. +## Types +Types are checked/inferred using the Hindley-Milner type inference algorithm. -To use it, simply apply the variable `callcc` like you would a function, e.g. `(callcc (\k. ...))`. +* Functions: `a -> b` (constructed by `\x. e`) +* Products: `a * b` (constructed by `(x, y)`) +* Unit: `★` (constructed by `()`) +* Sums: `a + b` (constructed by `Left x` or `Right y`) +* Bottom: `⊥` (currently useless because incomplete patterns are allowed) +* The natural numbers: `Nat` (constructed by `Z` and `S`) +* Lists: `List a` (constructed by `[]` and `(x :: xs)`) +* Characters: `Char` (constructed by `Char`, which takes a `Nat`) +* Universal quantification (forall): `∀a b. t` + +## Builtins +Builtins are variables that correspond with a built-in language feature +that cannot be replicated by user-written code. +They still are just variables though; they do not receive special syntactic treatment. + +* `fix : ∀a. ((a -> a) -> a)`: an alias for the strict fixpoint combinator that the typechecker can typecheck. +* `callcc : ∀a b. (((a -> b) -> a) -> a)`: [the call-with-current-continuation control flow operator](https://en.wikipedia.org/wiki/Call-with-current-continuation). Continuations are printed as `λ!. ... ! ...`, like a lambda abstraction with an argument named `!` which is used exactly once; @@ -41,56 +59,61 @@ because they perform the side effect of modifying the current continuation, and this is *not* valid syntax you can input into the REPL. ## Example code -The fixpoint function: -``` -(\x. x x) \fix f x. f (fix fix f) x -``` - Create a list by iterating `f` `n` times: ``` -fix \iterate f x. { Z -> x ; S n -> iterate f (f x) n } +fix \iterate f x. { Z -> [] ; S n -> (x :: iterate f (f x) n) } +: ∀e. ((e -> e) -> (e -> (Nat -> [e]))) ``` -Create a list whose first element is `n - 1`, counting down to a last element of `0`: +Use the iterate function to count to 10: ``` -\n. { (n, x) -> x } (iterate { (n, x) -> (S n, (n : x)) } (0, []) n) -``` - -Putting it all together to count down from 10: -``` ->> let fix = (\x. x x) \fix f x. f (fix fix f) x; iterate = fix \iterate f x. { Z -> x ; S n -> iterate f (f x) n }; countDownFrom = \n. { (n, x) -> x } (iterate { (n, x) -> (S n, (n : x)) } (0, []) n) in countDownFrom 10 -[9, 8, 7, 6, 5, 4, 3, 2, 1, 0] +>> let iterate = fix \iterate f x. { Z -> [] ; S n -> (x :: iterate f (f x) n) }; countTo = iterate S 1 in countTo 10 +: [Nat] +[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ``` Append two lists together: ``` -fix \append xs ys. { [] -> ys ; (x : xs) -> (x : append xs ys) } xs +fix \append xs ys. { [] -> ys ; (x :: xs) -> (x :: append xs ys) } xs +: ∀j. ([j] -> ([j] -> [j])) ``` Reverse a list: ``` -fix \reverse. { [] -> [] ; (x : xs) -> append (reverse xs) [x] } +fix \reverse. { [] -> [] ; (x :: xs) -> append (reverse xs) [x] } +: ∀c1. ([c1] -> [c1]) ``` Putting them together so we can reverse `"reverse"`: ``` ->> let fix = (\x. x x) \fix f x. f (fix fix f) x; append = fix \append xs ys. { [] -> ys ; (x : xs) -> (x : append xs ys) } xs; reverse = fix \reverse. { [] -> [] ; (x : xs) -> append (reverse xs) [x] } in reverse "reverse" +>> let append = fix \append xs ys. { [] -> ys ; (x :: xs) -> (x :: append xs ys) } xs; reverse = fix \reverse. { [] -> [] ; (x :: xs) -> append (reverse xs) [x] } in reverse "reverse" +: [Char] "esrever" ``` Calculating `3 + 2` with the help of Church-encoded numerals: ``` >> let Sf = \n f x. f (n f x); plus = \x. x Sf in plus (\f x. f (f (f x))) (\f x. f (f x)) S Z +: Nat 5 ``` This expression would loop forever, but `callcc` saves the day! ``` ->> y (callcc \k. (\x. (\x. x x) (\x. x x)) (k z)) -y z +>> S (callcc \k. (fix \x. x) (k Z)) +: Nat +1 ``` -A few other weird expressions: +And if it wasn't clear, this is what the `Char` constructor does: + +``` +>> { Char c -> Char (S c) } 'a +: Char +'b +``` + +Here are a few expressions which don't typecheck but are handy for debugging the evaluator: ``` >> let D = \x. x x; F = \f. f (f y) in D (F \x. x) y y @@ -98,6 +121,4 @@ y y y >> (\x y z. x y) y λy' z. y y' ->> { Char c -> Char (S c) } 'a -'b ``` diff --git a/app/Main.hs b/app/Main.hs index cd99939..358c5a1 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -15,8 +15,10 @@ prompt text = do getLine main :: IO () -main = forever $ parseEval <$> prompt ">> " >>= \case +main = forever $ parseCheck <$> prompt ">> " >>= \case Left parseError -> putStrLn $ "Parse error: " <> pack (show parseError) -- TODO: Support choosing which version to use at runtime. - Right expr -> putStrLn $ unparseEval $ eval expr - --Right expr -> mapM_ (putStrLn . unparseEval) $ snd $ traceEval expr + Right expr -> do + either putStrLn (putStrLn . (": " <>) . unparseScheme) $ infer expr + putStrLn $ unparseEval $ eval $ check2eval expr + --mapM_ (putStrLn . unparseEval) $ snd $ traceEval $ check2eval expr diff --git a/package.yaml b/package.yaml index 15b1ed4..823c836 100644 --- a/package.yaml +++ b/package.yaml @@ -20,9 +20,11 @@ default-extensions: - FlexibleContexts - FlexibleInstances - ImportQualifiedPost +- InstanceSigs - LambdaCase - OverloadedStrings - PatternSynonyms +- ScopedTypeVariables - StandaloneDeriving - ViewPatterns # Required for use of the 'trees that grow' pattern @@ -32,7 +34,6 @@ default-extensions: - DeriveFoldable - DeriveFunctor - DeriveTraversable -- TemplateHaskell dependencies: - base >= 4.14 && < 5 diff --git a/src/LambdaCalculus.hs b/src/LambdaCalculus.hs index 139e6a0..e11a9e1 100644 --- a/src/LambdaCalculus.hs +++ b/src/LambdaCalculus.hs @@ -2,15 +2,23 @@ module LambdaCalculus ( module LambdaCalculus.Evaluator , module LambdaCalculus.Expression , module LambdaCalculus.Syntax - , parseEval, unparseEval + , module LambdaCalculus.Types + , parseCheck, parseEval, unparseCheck, unparseEval ) where import LambdaCalculus.Evaluator import LambdaCalculus.Expression import LambdaCalculus.Syntax +import LambdaCalculus.Types + +parseCheck :: Text -> Either ParseError CheckExpr +parseCheck = fmap ast2check . parseAST parseEval :: Text -> Either ParseError EvalExpr parseEval = fmap ast2eval . parseAST +unparseCheck :: CheckExpr -> Text +unparseCheck = unparseAST . simplify . check2ast + unparseEval :: EvalExpr -> Text unparseEval = unparseAST . simplify . eval2ast diff --git a/src/LambdaCalculus/Evaluator.hs b/src/LambdaCalculus/Evaluator.hs index 2382ce6..82624c0 100644 --- a/src/LambdaCalculus/Evaluator.hs +++ b/src/LambdaCalculus/Evaluator.hs @@ -2,130 +2,29 @@ module LambdaCalculus.Evaluator ( Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), VoidF, UnitF (..), Text , Eval, EvalExpr, EvalX, EvalXF (..) , pattern AppFE, pattern CtrE, pattern CtrFE - , pattern Cont, pattern ContF, pattern CallCC, pattern CallCCF - , eval, traceEval, substitute, alphaConvert + , pattern ContE, pattern ContFE, pattern CallCCE, pattern CallCCFE + , eval, traceEval ) where import LambdaCalculus.Evaluator.Base import LambdaCalculus.Evaluator.Continuation -import Control.Monad (forM) import Control.Monad.Except (MonadError, ExceptT, throwError, runExceptT) -import Control.Monad.State (MonadState, State, evalState, modify', state, put, gets) +import Control.Monad.State (MonadState, evalState, modify', state, put, gets) import Control.Monad.Writer (runWriterT, tell) -import Data.Foldable (fold) -import Data.Functor.Foldable (cata, para, embed) -import Data.HashSet (HashSet) -import Data.HashSet qualified as HS -import Data.Stream qualified as S -import Data.Text qualified as T +import Data.HashMap.Strict qualified as HM import Data.Void (Void, absurd) --- | Free variables are variables which are present in an expression but not bound by any abstraction. -freeVars :: EvalExpr -> HashSet Text -freeVars = cata \case - VarF n -> HS.singleton n - AbsF n e -> HS.delete n e - ContF e -> HS.delete "!" e - CaseF ps -> foldMap (\(Pat _ ns e) -> HS.difference e (HS.fromList ns)) ps - e -> fold e - --- | Bound variables are variables which are bound by any form of abstraction in an expression. -boundVars :: EvalExpr -> HashSet Text -boundVars = cata \case - AbsF n e -> HS.insert n e - ContF e -> HS.insert "!" e - CaseF ps -> foldMap (\(Pat _ ns e) -> HS.union (HS.fromList ns) e) ps - e -> fold e - --- | Vars that occur anywhere in an experession, bound or free. -usedVars :: EvalExpr -> HashSet Text -usedVars x = HS.union (freeVars x) (boundVars x) - --- | Substitution is the process of replacing all free occurrences of a variable in one expression with another expression. -substitute :: Text -> EvalExpr -> EvalExpr -> EvalExpr -substitute var val = unsafeSubstitute var val . alphaConvert (freeVars val) - --- | Substitution is only safe if the bound variables in the body --- are disjoint from the free variables in the argument; --- this function makes an expression body safe for substitution --- by replacing the bound variables in the body --- with completely new variables which do not occur in either expression --- (without changing any *free* variables in the body, of course). -alphaConvert :: HashSet Text -> EvalExpr -> EvalExpr -alphaConvert ctx e_ = evalState (alphaConverter e_) $ HS.union ctx (usedVars e_) - where - alphaConverter :: EvalExpr -> State (HashSet Text) EvalExpr - alphaConverter = cata \case - e - | AbsF n e' <- e, n `HS.member` ctx -> do - n' <- fresh n - e'' <- e' - pure $ Abs n' $ replace n n' e'' - -- | TODO: Only replace the names that *have* to be replaced. - | CaseF ps <- e, any (any (`HS.member` ctx) . patNames) ps -> - Case <$> forM ps \(Pat ctr ns e') -> do - ns' <- mapM fresh ns - e'' <- e' - pure $ Pat ctr ns' $ foldr (uncurry replace) e'' (zip ns ns') - | otherwise -> embed <$> sequenceA e - - -- | Create a new name which is not used anywhere else. - fresh :: Text -> State (HashSet Text) Text - fresh n = state \ctx' -> - let n' = S.head $ S.filter (not . (`HS.member` ctx')) names - in (n', HS.insert n' ctx') - where names = S.iterate (`T.snoc` '\'') n - --- | Replace a name with an entirely new name in all contexts. --- This will only give correct results if --- the new name does not occur anywhere in the expression. -replace :: Text -> Text -> EvalExpr -> EvalExpr -replace name name' = cata \case - e - | VarF name2 <- e, name == name2 -> Var name' - | AbsF name2 e' <- e, name == name2 -> Abs name' e' - | CaseF ps <- e -> Case $ flip map ps \(Pat ctr ns e') -> Pat ctr (replace' ns) e' - | otherwise -> embed e - where - replace' = map \case - n - | n == name -> name' - | otherwise -> n - --- | Substitution which does *not* avoid variable capture; --- it only gives the correct result if the bound variables in the body --- are disjoint from the free variables in the argument. -unsafeSubstitute :: Text -> EvalExpr -> EvalExpr -> EvalExpr -unsafeSubstitute var val = para \case - e' - | VarF var2 <- e', var == var2 -> val - | AbsF var2 _ <- e', var == var2 -> unmodified e' - | ContF _ <- e', var == "!" -> unmodified e' - | CaseF ps <- e' -> Case $ flip map ps \(Pat ctr ns (unmod, sub)) -> - Pat ctr ns if var `elem` ns then unmod else sub - | otherwise -> substituted e' - where - substituted, unmodified :: EvalExprF (EvalExpr, EvalExpr) -> EvalExpr - substituted = embed . fmap snd - unmodified = embed . fmap fst - isReducible :: EvalExpr -> Bool -isReducible = snd . cata \case - AppFE ctr args -> active ctr [args] - AbsF _ _ -> passive - ContF _ -> passive - CaseF _ -> passive - CallCCF -> passive - CtrFE _ -> constant - VarF _ -> constant - where - -- | Constants are irreducible in any context. - constant = (False, False) - -- | Passive expressions are reducible only if an active expression is applied to them. - passive = (True, False) - -- | Active expressions are reducible if they are applied to a constructor or their arguments are reducible. - active ctr args = (False, fst ctr || snd ctr || any snd args) +-- Applications of function type constructors +isReducible (App (Abs _ _) _) = True +isReducible (App (ContE _) _) = True +isReducible (App CallCCE _) = True +-- Pattern matching of data +isReducible (App (Case _) ex) = isData ex || isReducible ex +-- Reducible subexpressions +isReducible (App ef ex) = isReducible ef || isReducible ex +isReducible _ = False lookupPat :: Ctr -> [Pat phase] -> Pat phase lookupPat ctr = foldr lookupCtr' (error "Constructor not found") @@ -182,21 +81,22 @@ evaluatorStep = \case | otherwise -> case ef of -- perform beta reduction if possible... Abs name body -> - pure $ substitute name ex body - Case pats - | isData ex -> do - let (ctr, xs) = toData ex - let Pat _ ns e = lookupPat ctr pats - pure $ foldr (uncurry substitute) e (zip ns xs) - | otherwise -> ret unmodified + pure $ substitute1 name ex body + Case pats -> + if isData ex + then do + let (ctr, xs) = toData ex + let Pat _ ns e = lookupPat ctr pats + pure $ substitute (HM.fromList $ zip ns xs) e + else ret unmodified -- perform continuation calls if possible... - Cont body -> do + ContE body -> do put [] - pure $ substitute "!" ex body + pure $ substitute1 "!" ex body -- capture the current continuation if requested... - CallCC -> do + CallCCE -> do k <- gets $ continue (Var "!") - pure $ App ex (Cont k) + pure $ App ex (ContE k) -- otherwise the value is irreducible and we can continue evaluation. _ -> ret unmodified -- Neither abstractions, constructors nor variables are reducible. diff --git a/src/LambdaCalculus/Evaluator/Base.hs b/src/LambdaCalculus/Evaluator/Base.hs index d2d68ff..111dff6 100644 --- a/src/LambdaCalculus/Evaluator/Base.hs +++ b/src/LambdaCalculus/Evaluator/Base.hs @@ -1,14 +1,22 @@ module LambdaCalculus.Evaluator.Base ( Identity (..) , Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), VoidF, UnitF (..), Text + , substitute, substitute1, rename, rename1, free, bound, used , Eval, EvalExpr, EvalExprF, EvalX, EvalXF (..) , pattern AppFE, pattern CtrE, pattern CtrFE - , pattern Cont, pattern ContF, pattern CallCC, pattern CallCCF + , pattern ContE, pattern ContFE, pattern CallCCE, pattern CallCCFE ) where import LambdaCalculus.Expression.Base +import Control.Monad (forM) +import Control.Monad.Reader (asks) +import Data.Bifunctor (first) +import Data.Foldable (fold) import Data.Functor.Identity (Identity (..)) +import Data.Functor.Foldable (embed, cata, para) +import Data.HashMap.Strict qualified as HM +import Data.Traversable (for) data Eval type EvalExpr = Expr Eval @@ -33,41 +41,62 @@ data EvalXF r -- -- Continuations do not have any corresponding surface-level syntax, -- but may be printed like a lambda with the illegal variable `!`. - = Cont_ !r + = ContE_ !r -- | Call-with-current-continuation, an evaluator built-in function. - | CallCC_ + | CallCCE_ deriving (Eq, Functor, Foldable, Traversable, Show) -instance RecursivePhase Eval where - projectAppArgs = Identity - embedAppArgs = runIdentity - pattern CtrE :: Ctr -> EvalExpr pattern CtrE c = Ctr c Unit pattern CtrFE :: Ctr -> EvalExprF r pattern CtrFE c = CtrF c Unit -pattern Cont :: EvalExpr -> EvalExpr -pattern Cont e = ExprX (Cont_ e) +pattern ContE :: EvalExpr -> EvalExpr +pattern ContE e = ExprX (ContE_ e) -pattern CallCC :: EvalExpr -pattern CallCC = ExprX CallCC_ +pattern CallCCE :: EvalExpr +pattern CallCCE = ExprX CallCCE_ -pattern ContF :: r -> EvalExprF r -pattern ContF e = ExprXF (Cont_ e) +pattern ContFE :: r -> EvalExprF r +pattern ContFE e = ExprXF (ContE_ e) -pattern CallCCF :: EvalExprF r -pattern CallCCF = ExprXF CallCC_ +pattern CallCCFE :: EvalExprF r +pattern CallCCFE = ExprXF CallCCE_ pattern AppFE :: r -> r -> EvalExprF r pattern AppFE ef ex = AppF ef (Identity ex) -{-# COMPLETE Var, App, Abs, Let, Ctr, Case, Cont, CallCC #-} -{-# COMPLETE VarF, AppF, AbsF, LetF, CtrF, CaseF, ContF, CallCCF #-} -{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ExprXF #-} -{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ContF, CallCCF #-} -{-# COMPLETE Var, App, Abs, Let, CtrE, Case, Cont, CallCC #-} -{-# COMPLETE VarF, AppF, AbsF, LetF, CtrFE, CaseF, ContF, CallCCF #-} -{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ExprXF #-} -{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ContF, CallCCF #-} +{-# COMPLETE Var, App, Abs, Let, Ctr, Case, ContE, CallCCE #-} +{-# COMPLETE VarF, AppF, AbsF, LetF, CtrF, CaseF, ContFE, CallCCFE #-} +{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ExprXF #-} +{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ContFE, CallCCFE #-} +{-# COMPLETE Var, App, Abs, Let, CtrE, Case, ContE, CallCCE #-} +{-# COMPLETE VarF, AppF, AbsF, LetF, CtrFE, CaseF, ContFE, CallCCFE #-} +{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ExprXF #-} +{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ContFE, CallCCFE #-} + +instance RecursivePhase Eval where + projectAppArgs = Identity + embedAppArgs = runIdentity + +instance Substitutable EvalExpr where + collectVars withVar withBinder = cata \case + VarF n -> withVar n + AbsF n e -> withBinder n e + CaseF pats -> foldMap (\(Pat _ ns e) -> foldr withBinder e ns) pats + e -> fold e + + rename = runRenamer $ \badNames -> cata \case + VarF n -> asks $ Var . HM.findWithDefault n n + AbsF n e -> uncurry Abs . first runIdentity <$> replaceNames badNames (Identity n) e + ContFE e -> ContE <$> e + CaseF ps -> Case <$> forM ps \(Pat ctr ns e) -> uncurry (Pat ctr) <$> replaceNames badNames ns e + e -> embed <$> sequenceA e + + unsafeSubstitute = runSubstituter $ para \case + VarF name -> asks $ HM.findWithDefault (Var name) name + AbsF name e -> Abs name <$> maySubstitute (Identity name) e + ContFE e -> ContE <$> maySubstitute (Identity "!") e + CaseF pats -> Case <$> for pats \(Pat ctr ns e) -> Pat ctr ns <$> maySubstitute ns e + e -> embed <$> traverse snd e diff --git a/src/LambdaCalculus/Evaluator/Continuation.hs b/src/LambdaCalculus/Evaluator/Continuation.hs index 009f29f..fa77e8a 100644 --- a/src/LambdaCalculus/Evaluator/Continuation.hs +++ b/src/LambdaCalculus/Evaluator/Continuation.hs @@ -1,6 +1,6 @@ module LambdaCalculus.Evaluator.Continuation ( Continuation, continue, continue1 - , ContinuationCrumb (ApplyTo, AppliedTo, AbstractedOver) + , ContinuationCrumb (..) ) where import LambdaCalculus.Evaluator.Base diff --git a/src/LambdaCalculus/Expression.hs b/src/LambdaCalculus/Expression.hs index 965738e..3f6ff5e 100644 --- a/src/LambdaCalculus/Expression.hs +++ b/src/LambdaCalculus/Expression.hs @@ -1,46 +1,94 @@ module LambdaCalculus.Expression ( Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), DefF (..), VoidF, UnitF (..), Text + , substitute, substitute1, rename, free, bound, used , Eval, EvalExpr, EvalX, EvalXF (..), Identity (..) , pattern AppFE, pattern CtrE, pattern CtrFE, - pattern Cont, pattern ContF, pattern CallCC, pattern CallCCF + pattern ContE, pattern ContFE, pattern CallCCE, pattern CallCCFE , Parse, AST, ASTF, ASTX, ASTXF (..), NonEmptyDefFs (..), NonEmpty (..), simplify , pattern LetFP, pattern PNat, pattern PNatF, pattern PList, pattern PListF - , pattern PChar, pattern PCharF, pattern PStr, pattern PStrF - , ast2eval, eval2ast + , pattern PChar, pattern PCharF, pattern PStr, pattern PStrF, pattern HoleP, pattern HoleFP + , Check, CheckExpr, CheckExprF, CheckX, CheckXF (..) + , pattern AppFC, pattern CtrC, pattern CtrFC, pattern CallCCC, pattern CallCCFC + , pattern FixC, pattern FixFC, pattern HoleC, pattern HoleFC + , Type (..), TypeF (..), Scheme (..), tapp + , ast2check, ast2eval, check2eval, check2ast, eval2ast ) where import LambdaCalculus.Evaluator.Base -import LambdaCalculus.Evaluator (alphaConvert, substitute) import LambdaCalculus.Syntax.Base +import LambdaCalculus.Types.Base import Data.Functor.Foldable (cata, hoist) -import Data.HashSet qualified as HS +import Data.HashMap.Strict (HashMap) +import Data.HashMap.Strict qualified as HM import Data.List (foldl') import Data.List.NonEmpty (toList) import Data.Text (unpack) --- | Convert from an abstract syntax tree to an evaluator expression. -ast2eval :: AST -> EvalExpr -ast2eval = substitute "callcc" CallCC . cata \case +builtins :: HashMap Text CheckExpr +builtins = HM.fromList [("callcc", CallCCC), ("fix", FixC)] + +-- | Convert from an abstract syntax tree to a typechecker expression. +ast2check :: AST -> CheckExpr +ast2check = substitute builtins . cata \case VarF name -> Var name AppF ef exs -> foldl' App ef $ toList exs AbsF ns e -> foldr Abs e $ toList ns LetF ds e -> let letExpr name val body' = App (Abs name body') val in foldr (uncurry letExpr) e $ getNonEmptyDefFs ds - CtrF ctr es -> foldl' App (CtrE ctr) es + CtrF ctr es -> foldl' App (CtrC ctr) es CaseF ps -> Case ps PNatF n -> int2ast n PListF es -> mkList es - PStrF s -> mkList $ map (App (CtrE CChar) . int2ast . fromEnum) $ unpack s - PCharF c -> App (CtrE CChar) (int2ast $ fromEnum c) + PStrF s -> mkList $ map (App (CtrC CChar) . int2ast . fromEnum) $ unpack s + PCharF c -> App (CtrC CChar) (int2ast $ fromEnum c) + HoleFP -> HoleC where - int2ast :: Int -> EvalExpr - int2ast 0 = CtrE CZero - int2ast n = App (CtrE CSucc) (int2ast (n - 1)) + int2ast :: Int -> CheckExpr + int2ast 0 = CtrC CZero + int2ast n = App (CtrC CSucc) (int2ast (n - 1)) - mkList :: [EvalExpr] -> EvalExpr - mkList = foldr (App . App (CtrE CCons)) (CtrE CNil) + mkList :: [CheckExpr] -> CheckExpr + mkList = foldr (App . App (CtrC CCons)) (CtrC CNil) + +-- | Convert from a typechecker expression to an evaluator expression. +check2eval :: CheckExpr -> EvalExpr +check2eval = cata \case + VarF name -> Var name + AppFC ef ex -> App ef ex + AbsF n e -> Abs n e + LetF (Def nx ex) e -> App (Abs nx e) ex + CtrFC ctr -> CtrE ctr + CaseF ps -> Case ps + CallCCFC -> CallCCE + FixFC -> z + HoleFC -> omega + where + z, omega :: EvalExpr + z = App omega $ Abs "fix" $ Abs "f" $ Abs "x" $ + App (App (Var "f") (App (App (Var "fix") (Var "fix")) (Var "f"))) (Var "x") + omega = Abs "x" (App (Var "x") (Var "x")) + +-- | Convert from an abstract syntax tree to an evaluator expression. +ast2eval :: AST -> EvalExpr +ast2eval = check2eval . ast2check + +-- | Convert from a typechecker expression to an abstract syntax tree. +check2ast :: CheckExpr -> AST +check2ast = hoist go . rename (HM.keysSet builtins) + where + go :: CheckExprF r -> ASTF r + go = \case + VarF name -> VarF name + AppFC ef ex -> AppF ef (ex :| []) + AbsF n e -> AbsF (n :| []) e + LetF (Def nx ex) e -> LetFP ((nx, ex) :| []) e + CtrFC ctr -> CtrF ctr [] + CaseF ps -> CaseF ps + CallCCFC-> VarF "callcc" + FixFC -> VarF "fix" + HoleFC -> HoleFP -- | Convert from an evaluator expression to an abstract syntax tree. eval2ast :: EvalExpr -> AST @@ -48,14 +96,14 @@ eval2ast :: EvalExpr -> AST -- all instances of `callcc` must be bound; -- therefore, we are free to alpha convert them, -- freeing the name `callcc` for us to use for the built-in again. -eval2ast = hoist go . alphaConvert (HS.singleton "callcc") +eval2ast = hoist go . rename (HM.keysSet builtins) where go :: EvalExprF r -> ASTF r go = \case VarF name -> VarF name - CallCCF -> VarF "callcc" AppFE ef ex -> AppF ef (ex :| []) AbsF n e -> AbsF (n :| []) e CtrFE ctr -> CtrF ctr [] CaseF ps -> CaseF ps - ContF e -> AbsF ("!" :| []) e + CallCCFE -> VarF "callcc" + ContFE e -> AbsF ("!" :| []) e diff --git a/src/LambdaCalculus/Expression/Base.hs b/src/LambdaCalculus/Expression/Base.hs index b0bb619..62ea39f 100644 --- a/src/LambdaCalculus/Expression/Base.hs +++ b/src/LambdaCalculus/Expression/Base.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE UndecidableInstances #-} module LambdaCalculus.Expression.Base ( Text, VoidF, UnitF (..), absurd' @@ -5,11 +6,24 @@ module LambdaCalculus.Expression.Base , ExprF (..), PatF (..), DefF (..), AppArgsF, LetArgsF, CtrArgsF, XExprF , RecursivePhase, projectAppArgs, projectLetArgs, projectCtrArgs, projectXExpr, projectDef , embedAppArgs, embedLetArgs, embedCtrArgs, embedXExpr, embedDef + , Substitutable, free, bound, used, collectVars, rename, rename1 + , substitute, substitute1, unsafeSubstitute, unsafeSubstitute1 + , runRenamer, freshVar, replaceNames, runSubstituter, maySubstitute ) where +import Control.Monad.Reader (MonadReader, Reader, runReader, asks, local) +import Control.Monad.State (MonadState, StateT, evalStateT, state) +import Control.Monad.Zip (MonadZip, mzipWith) +import Data.Foldable (fold) import Data.Functor.Foldable (Base, Recursive, Corecursive, project, embed) +import Data.HashMap.Strict (HashMap) +import Data.HashMap.Strict qualified as HM +import Data.HashSet (HashSet) +import Data.HashSet qualified as HS import Data.Kind (Type) +import Data.Stream qualified as S import Data.Text (Text) +import Data.Text qualified as T data Expr phase -- | A variable: `x`. @@ -100,8 +114,8 @@ type family LetArgsF phase :: Type -> Type type family CtrArgsF phase :: Type -> Type type family XExprF phase :: Type -> Type -data DefF r = DefF !Text !r - deriving (Eq, Functor, Show) +data DefF r = Def !Text !r + deriving (Eq, Functor, Foldable, Traversable, Show) -- | A contractible data type with one extra type parameter. data UnitF a = Unit @@ -199,10 +213,10 @@ class Functor (ExprF phase) => RecursivePhase phase where embedXExpr = id projectDef :: Def phase -> DefF (Expr phase) -projectDef = uncurry DefF +projectDef = uncurry Def embedDef :: DefF (Expr phase) -> Def phase -embedDef (DefF n e) = (n, e) +embedDef (Def n e) = (n, e) instance RecursivePhase phase => Recursive (Expr phase) where project = \case @@ -227,3 +241,117 @@ instance RecursivePhase phase => Corecursive (Expr phase) where --- --- End base functor boilerplate. --- + +class Substitutable e where + -- | Fold over the variables in the expression with a monoid, + -- given what to do with variable usage sites and binding sites respectively. + collectVars :: Monoid m => (Text -> m) -> (Text -> m -> m) -> e -> m + + -- | Free variables are variables which occur anywhere in an expression + -- where they are not bound by an abstraction. + free :: e -> HashSet Text + free = collectVars HS.singleton HS.delete + + -- | Bound variables are variables which are abstracted over anywhere in an expression. + bound :: e -> HashSet Text + bound = collectVars (const HS.empty) HS.insert + + -- | Used variables are variables which appear *anywhere* in an expression, free or bound. + used :: e -> HashSet Text + used = collectVars HS.singleton HS.insert + + -- | Given a map between variable names and expressions, + -- replace each free occurrence of a variable with its respective expression. + substitute :: HashMap Text e -> e -> e + substitute substs = unsafeSubstitute substs . rename (foldMap free substs) + + substitute1 :: Text -> e -> e -> e + substitute1 n e = substitute (HM.singleton n e) + + -- | Rename all bound variables in an expression (both a binding sites and usage sites) + -- with new names where the new names are *not* members of the provided set. + rename :: HashSet Text -> e -> e + + rename1 :: Text -> e -> e + rename1 n = rename (HS.singleton n) + + -- | A variant of substitution which does *not* avoid variable capture; + -- it only gives the correct result if the bound variables in the body + -- are disjoint from the free variables in the argument. + unsafeSubstitute :: HashMap Text e -> e -> e + + unsafeSubstitute1 :: Text -> e -> e -> e + unsafeSubstitute1 n e = unsafeSubstitute (HM.singleton n e) + +-- +-- These primitives are likely to be useful for implementing `rename`. +-- Ideally, I would like to find a way to move the implementation of `rename` here entirely, +-- but I haven't yet figured out an appropriate abstraction to do so. +-- + +-- | Run an action which requires a stateful context of used variable names +-- and a local context of variable replacements. +-- +-- This is a useful monad for implementing the `rename` function. +runRenamer :: Substitutable e + => (HashSet Text -> e -> StateT (HashSet Text) (Reader (HashMap Text Text)) a) + -> HashSet Text + -> e + -> a +runRenamer m ctx e = runReader (evalStateT (m ctx e) dirtyNames) HM.empty + where dirtyNames = HS.union ctx (used e) + +-- | Create a new variable name within a context of used variable names. +freshVar :: MonadState (HashSet Text) m => Text -> m Text +freshVar baseName = + state \ctx -> let name = newName ctx in (name, HS.insert name ctx) + where + names = S.iterate (`T.snoc` '\'') baseName + newName ctx = S.head $ S.filter (not . (`HS.member` ctx)) names + +-- | Replace a collection of old variable names with new variable names +-- and apply those replacements within a context. +replaceNames :: ( MonadReader (HashMap Text Text) m + , MonadState (HashSet Text) m + , MonadZip t, Traversable t + ) + => HashSet Text -> t Text -> m a -> m (t Text, a) +replaceNames badNames names m = do + newNames <- mapM freshVarIfNecessary names + let replacements = HM.filterWithKey (/=) $ fold $ mzipWith HM.singleton names newNames + x <- local (HM.union replacements) m + pure (newNames, x) + where + freshVarIfNecessary name + | name `HS.member` badNames = freshVar name + | otherwise = pure name + +--- +--- The same as the above section but for `substitute`. +--- This is useful when implementing substitution as a paramorphism. +--- + +-- | Run an action in a local context of substitutions. +-- +-- This monad is useful for implementing, you guessed it, substitution. +runSubstituter :: (e -> Reader (HashMap Text e) a) + -> HashMap Text e + -> e + -> a +runSubstituter m substs e = runReader (m e) substs + +-- | Apply only the substitutions which are not bound, +-- and only if there are substitutions left to apply. +maySubstitute :: ( MonadReader (HashMap Text b) m + , Functor t, Foldable t + ) + => t Text -> (a, m a) -> m a +maySubstitute ns (unmodified, substituted) = + local (compose $ fmap HM.delete ns) do + noMoreSubsts <- asks HM.null + if noMoreSubsts + then pure unmodified + else substituted + +compose :: Foldable t => t (a -> a) -> a -> a +compose = foldr (.) id diff --git a/src/LambdaCalculus/Syntax/Base.hs b/src/LambdaCalculus/Syntax/Base.hs index 81161c3..18610e9 100644 --- a/src/LambdaCalculus/Syntax/Base.hs +++ b/src/LambdaCalculus/Syntax/Base.hs @@ -1,8 +1,9 @@ module LambdaCalculus.Syntax.Base ( Expr (..), ExprF (..), Ctr (..), Pat, Def, DefF (..), PatF (..), VoidF, Text, NonEmpty (..) + , substitute, substitute1, rename, rename1, free, bound, used , Parse, AST, ASTF, ASTX, ASTXF (..), NonEmptyDefFs (..) , pattern LetFP, pattern PNat, pattern PNatF, pattern PList, pattern PListF - , pattern PChar, pattern PCharF, pattern PStr, pattern PStrF + , pattern PChar, pattern PCharF, pattern PStr, pattern PStrF, pattern HoleP, pattern HoleFP , simplify ) where @@ -53,6 +54,8 @@ data ASTXF r | PChar_ Char -- | A string literal, e.g. `"abcd"`. | PStr_ Text + -- | A type hole. + | HoleP_ deriving (Eq, Functor, Foldable, Traversable, Show) instance RecursivePhase Parse where @@ -62,7 +65,6 @@ instance RecursivePhase Parse where newtype NonEmptyDefFs r = NonEmptyDefFs { getNonEmptyDefFs :: NonEmpty (Text, r) } deriving (Eq, Functor, Foldable, Traversable, Show) - pattern LetFP :: NonEmpty (Text, r) -> r -> ASTF r pattern LetFP ds e = LetF (NonEmptyDefFs ds) e @@ -90,10 +92,18 @@ pattern PStrF s = ExprXF (PStr_ s) pattern PStr :: Text -> AST pattern PStr s = ExprX (PStr_ s) -{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, ExprXF #-} -{-# COMPLETE Var, App, Abs, Let, Ctr, Case, PNat, PList, PChar, PStr #-} -{-# COMPLETE VarF, AppF, AbsF, LetF , CtrF, CaseF, PNatF, PListF, PCharF, PStrF #-} -{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, PNatF, PListF, PCharF, PStrF #-} +pattern HoleP :: AST +pattern HoleP = ExprX HoleP_ + +pattern HoleFP :: ASTF r +pattern HoleFP = ExprXF HoleP_ + +{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, ExprXF #-} +{-# COMPLETE Var, App, Abs, Let, Ctr, Case, PNat, PList, PChar, PStr, HoleP #-} +{-# COMPLETE VarF, AppF, AbsF, LetF , CtrF, CaseF, PNatF, PListF, PCharF, PStrF, HoleFP #-} +{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, PNatF, PListF, PCharF, PStrF, HoleFP #-} + +-- TODO: Implement Substitutable for AST. -- | Combine nested expressions into compound expressions or literals when possible. simplify :: AST -> AST diff --git a/src/LambdaCalculus/Syntax/Parser.hs b/src/LambdaCalculus/Syntax/Parser.hs index c3ed956..0876be1 100644 --- a/src/LambdaCalculus/Syntax/Parser.hs +++ b/src/LambdaCalculus/Syntax/Parser.hs @@ -86,10 +86,10 @@ ctr = pair <|> unit <|> either <|> nat <|> list <|> str succ = Ctr CSucc [] <$ keyword "S" natLit = (PNat . read <$> many1 digit) <* spaces list = cons <|> consCtr <|> listLit - consCtr = Ctr CCons [] <$ keyword "(:)" + consCtr = Ctr CCons [] <$ keyword "(::)" cons = try $ between (token '(') (token ')') do e1 <- ambiguous - token ':' + keyword "::" e2 <- ambiguous pure $ Ctr CCons [e1, e2] listLit = fmap PList $ between (token '[') (token ']') $ sepEndBy ambiguous (token ',') @@ -105,35 +105,36 @@ pat = label "case alternate" $ do keyword "->" e <- ambiguous pure $ Pat c ns e - where pair = try $ between (token '(') (token ')') do - e1 <- identifier - token ',' - e2 <- identifier - pure (CPair, [e1, e2]) - unit = (CUnit, []) <$ keyword "()" - left = do - keyword "Left" - e <- identifier - pure (CLeft, [e]) - right = do - keyword "Right" - e <- identifier - pure (CRight, [e]) - zero = (CZero, []) <$ keyword "Z" - succ = do - keyword "S" - e <- identifier - pure (CSucc, [e]) - nil = (CNil, []) <$ keyword "[]" - cons = try $ between (token '(') (token ')') do - e1 <- identifier - token ':' - e2 <- identifier - pure (CCons, [e1, e2]) - char' = do - keyword "Char" - e <- identifier - pure (CChar, [e]) + where + pair = try $ between (token '(') (token ')') do + e1 <- identifier + token ',' + e2 <- identifier + pure (CPair, [e1, e2]) + unit = (CUnit, []) <$ keyword "()" + left = do + keyword "Left" + e <- identifier + pure (CLeft, [e]) + right = do + keyword "Right" + e <- identifier + pure (CRight, [e]) + zero = (CZero, []) <$ keyword "Z" + succ = do + keyword "S" + e <- identifier + pure (CSucc, [e]) + nil = (CNil, []) <$ keyword "[]" + cons = try $ between (token '(') (token ')') do + e1 <- identifier + keyword "::" + e2 <- identifier + pure (CCons, [e1, e2]) + char' = do + keyword "Char" + e <- identifier + pure (CChar, [e]) case_ :: Parser AST case_ = label "case patterns" $ do @@ -142,13 +143,16 @@ case_ = label "case patterns" $ do token '}' pure $ Case pats +hole :: Parser AST +hole = label "hole" $ HoleP <$ token '_' + -- | Guaranteed to consume a finite amount of input finite :: Parser AST -finite = label "finite expression" $ variable <|> ctr <|> case_ <|> grouping +finite = label "finite expression" $ variable <|> hole <|> ctr <|> case_ <|> grouping -- | Guaranteed to consume input, but may continue until it reaches a terminator block :: Parser AST -block = label "block expression" $ finite <|> abstraction <|> let_ +block = label "block expression" $ abstraction <|> let_ <|> finite -- | Not guaranteed to consume input at all, may continue until it reaches a terminator ambiguous :: Parser AST diff --git a/src/LambdaCalculus/Syntax/Printer.hs b/src/LambdaCalculus/Syntax/Printer.hs index 5b0743f..481b1d7 100644 --- a/src/LambdaCalculus/Syntax/Printer.hs +++ b/src/LambdaCalculus/Syntax/Printer.hs @@ -71,6 +71,7 @@ unparseAST = toStrict . toLazyText . snd . cata \case in tag Finite $ "[" <> es' <> "]" PStrF s -> tag Finite $ "\"" <> fromText s <> "\"" PCharF c -> tag Finite $ "'" <> fromLazyText (singleton c) + HoleFP -> tag Finite "_" where unparseApp :: Tagged Builder -> NonEmpty (Tagged Builder) -> Tagged Builder unparseApp ef (unsnoc -> (exs, efinal)) @@ -78,8 +79,8 @@ unparseAST = toStrict . toLazyText . snd . cata \case unparseCtr :: Ctr -> [Tagged Builder] -> Tagged Builder -- Fully-applied special syntax forms - unparseCtr CPair [x, y] = tag Finite $ "(" <> unambiguous x <> ", " <> unambiguous y <> ")" - unparseCtr CCons [x, y] = tag Finite $ "(" <> unambiguous x <> " : " <> unambiguous y <> ")" + unparseCtr CPair [x, y] = tag Finite $ "(" <> unambiguous x <> ", " <> unambiguous y <> ")" + unparseCtr CCons [x, y] = tag Finite $ "(" <> unambiguous x <> " :: " <> unambiguous y <> ")" -- Partially-applied syntax forms unparseCtr CUnit [] = tag Finite "()" unparseCtr CPair [] = tag Finite "(,)" @@ -88,7 +89,7 @@ unparseAST = toStrict . toLazyText . snd . cata \case unparseCtr CZero [] = tag Finite "Z" unparseCtr CSucc [] = tag Finite "S" unparseCtr CNil [] = tag Finite "[]" - unparseCtr CCons [] = tag Finite "(:)" + unparseCtr CCons [] = tag Finite "(::)" unparseCtr CChar [] = tag Finite "Char" unparseCtr ctr (x:xs) = unparseApp (unparseCtr ctr []) (x :| xs) diff --git a/src/LambdaCalculus/Types.hs b/src/LambdaCalculus/Types.hs new file mode 100644 index 0000000..97bacd0 --- /dev/null +++ b/src/LambdaCalculus/Types.hs @@ -0,0 +1,158 @@ +module LambdaCalculus.Types + ( module LambdaCalculus.Types.Base + , infer + ) where + +import LambdaCalculus.Types.Base + +import Control.Applicative ((<|>)) +import Control.Monad (when) +import Control.Monad.Except (MonadError, throwError) +import Control.Monad.Reader (MonadReader, runReader, asks, local) +import Control.Monad.RWS + ( RWST, evalRWST + , MonadState, state + , MonadWriter, tell, listen + ) +import Data.Foldable (forM_, toList) +import Data.HashSet qualified as HS +import Data.HashMap.Strict qualified as HM +import Data.Stream (Stream (..), fromList) +import Data.Text qualified as T + +fresh :: MonadState (Stream Text) m => m Type +fresh = state \(Cons n ns) -> (TVar n, ns) + +inst :: MonadState (Stream Text) m => Scheme -> m Type +inst (TForall ns t) = foldr (\n_n t' -> substitute1 n_n <$> fresh <*> t') (pure t) ns + +lookupVar :: (MonadReader Context m, MonadState (Stream Text) m, MonadError Text m) => Text -> m Type +lookupVar n = do + t_polyq <- asks (HM.!? n) + case t_polyq of + Nothing -> throwError $ "Variable not bound: " <> n + Just t_poly -> inst t_poly + +generalize :: MonadReader Context m => Type -> m Scheme +generalize t = do + ctx <- asks HM.keysSet + pure $ TForall (toList $ HS.difference (free t) ctx) t + +bindVar :: MonadReader Context m => Text -> Type -> m a -> m a +bindVar n t = local (HM.insert n (TForall [] t)) + +unify :: MonadWriter [Constraint] m => Type -> Type -> m () +unify t1 t2 = tell [(t1, t2)] + +ctrTy :: MonadState (Stream Text) m => Ctr -> m (Type, [Type]) +ctrTy = \case + CUnit -> pure (TUnit, []) + CZero -> pure (TNat, []) + CSucc -> pure (TNat, [TNat]) + CChar -> pure (TChar, [TNat]) + CNil -> mkUnary TList $ const [] + CCons -> mkUnary TList \t_a -> [t_a, TApp TList t_a] + CPair -> mkBinary TProd \t_a t_b -> [t_a, t_b] + CLeft -> mkBinary TSum \t_a _ -> [t_a] + CRight -> mkBinary TSum \_ t_b -> [t_b] + where + mkBinary tc tcas = do + t_a <- fresh + t_b <- fresh + pure (tapp [tc, t_a, t_b], tcas t_a t_b) + + mkUnary tc tcas = do + t_a <- fresh + pure (TApp tc t_a, tcas t_a) + +j :: (MonadError Text m, MonadReader Context m, MonadState (Stream Text) m, MonadWriter [Constraint] m) + => CheckExpr -> m Type +j (Var name) = lookupVar name +j (App e_fun e_arg) = do + t_ret <- fresh + t_fun <- j e_fun + t_arg <- j e_arg + unify t_fun (tapp [TAbs, t_arg, t_ret]) + pure t_ret +j (Abs n_arg e_ret) = do + t_arg <- fresh + t_ret <- bindVar n_arg t_arg $ j e_ret + pure $ tapp [TAbs, t_arg, t_ret] +j (Let (n_x, e_x) e_ret) = do + (t_x_mono, c) <- listen $ j e_x + s <- solve' c + t_x_poly <- generalize $ substitute s t_x_mono + local (HM.insert n_x t_x_poly) $ j e_ret +-- In a case expression: +-- * the pattern for each branch has the same type as the expression being matched, and +-- * the return type for each branch has the same type as the return type of the case expression as a whole. +j (Case ctrs) = do + t_ret <- fresh + t_x <- fresh + forM_ ctrs \(Pat ctr ns_n e) -> do + (t_x', ts_n) <- ctrTy ctr + unify t_x t_x' + when (length ts_n /= length ns_n) $ throwError "Constructor arity mismatch" + t_ret' <- local (HM.union $ HM.fromList $ zip ns_n $ map (TForall []) ts_n) $ j e + unify t_ret t_ret' + pure $ tapp [TAbs, t_x, t_ret] +j (CtrC ctr) = do + (t_ret, ts_n) <- ctrTy ctr + pure $ foldr (\t_a t_r -> tapp [TAbs, t_a, t_r]) t_ret ts_n +j CallCCC = do + t_a <- fresh + t_b <- fresh + pure $ tapp [TAbs, tapp [TAbs, tapp [TAbs, t_a, t_b], t_a], t_a] +j FixC = do + t_a <- fresh + pure $ tapp [TAbs, tapp [TAbs, t_a, t_a], t_a] +j HoleC = asks show >>= throwError . (<>) "Encountered hole with context: " . T.pack + +occurs :: Text -> Type -> Bool +occurs n t = HS.member n (free t) + +findDifference :: MonadError (Type, Type) m => Type -> Type -> m (Maybe (Text, Type)) +findDifference t1 t2 + | t1 == t2 = pure Nothing + | TVar n1 <- t1, not (occurs n1 t2) = pure $ Just (n1, t2) + | TVar _ <- t2 = findDifference t2 t1 + | TApp a1 b1 <- t1, TApp a2 b2 <- t2 = (<|>) <$> findDifference a1 a2 <*> findDifference b1 b2 + | otherwise = throwError (t1, t2) + +unifies :: MonadError (Type, Type) m => Type -> Type -> m Substitution +unifies t1 t2 = do + dq <- findDifference t1 t2 + case dq of + Nothing -> pure HM.empty + Just s -> do + ss <- unifies (uncurry substitute1 s t1) (uncurry substitute1 s t2) + pure $ uncurry HM.insert (fmap (substitute ss) s) ss + +solve :: MonadError (Type, Type) m => [Constraint] -> m Substitution +solve [] = pure HM.empty +solve (c:cs) = do + s <- uncurry unifies c + ss <- solve (substituteMono s cs) + pure $ HM.union ss (substituteMono ss s) + +solve' :: MonadError Text m => [Constraint] -> m Substitution +solve' c = case solve c of + Right ss -> pure ss + Left (t1, t2) -> throwError $ "Could not unify " <> unparseType t1 <> " with " <> unparseType t2 + +type Inferencer a = RWST Context [Constraint] (Stream Text) (Either Text) a + +runInferencer :: Inferencer a -> Either Text (a, [Constraint]) +runInferencer m = evalRWST m HM.empty freshNames + where + freshNames = fromList $ do + n <- [0 :: Int ..] + c <- ['a'..'z'] + pure $ T.pack if n == 0 then [c] else c : show n + +infer :: CheckExpr -> Either Text Scheme +infer e = do + (t, c) <- runInferencer $ j e + s <- solve' c + let t' = substitute s t + pure $ runReader (generalize t') HM.empty diff --git a/src/LambdaCalculus/Types/Base.hs b/src/LambdaCalculus/Types/Base.hs new file mode 100644 index 0000000..91f2494 --- /dev/null +++ b/src/LambdaCalculus/Types/Base.hs @@ -0,0 +1,242 @@ +{-# LANGUAGE TemplateHaskell #-} +module LambdaCalculus.Types.Base + ( Identity (..) + , Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), VoidF, UnitF (..), Text + , substitute, substitute1, rename, rename1, free, bound, used + , Check, CheckExpr, CheckExprF, CheckX, CheckXF (..) + , pattern AppFC, pattern CtrC, pattern CtrFC, pattern CallCCC, pattern CallCCFC + , pattern FixC, pattern FixFC, pattern HoleC, pattern HoleFC + , Type (..), TypeF (..), Scheme (..), tapp + , Substitution, Context, Constraint + , MonoSubstitutable, substituteMono, substituteMono1 + , unparseType, unparseScheme + ) where + +import Control.Monad (forM) +import Control.Monad.Reader (asks) +import Data.Bifunctor (bimap, first) +import Data.Foldable (fold) +import Data.Functor.Foldable (embed, cata, para) +import Data.Functor.Foldable.TH (makeBaseFunctor) +import Data.Functor.Identity (Identity (..)) +import Data.HashMap.Strict (HashMap) +import Data.HashMap.Strict qualified as HM +import Data.List (foldl1') +import Data.Text qualified as T +import Data.Traversable (for) +import LambdaCalculus.Expression.Base + +data Check +type CheckExpr = Expr Check +type instance AppArgs Check = CheckExpr +type instance AbsArgs Check = Text +type instance LetArgs Check = (Text, CheckExpr) +type instance CtrArgs Check = UnitF CheckExpr +type instance XExpr Check = CheckX + +type CheckX = CheckXF CheckExpr + +type CheckExprF = ExprF Check +type instance AppArgsF Check = Identity +type instance LetArgsF Check = DefF +type instance CtrArgsF Check = UnitF +type instance XExprF Check = CheckXF + +data CheckXF r + -- | Call-with-current-continuation. + = CallCCC_ + -- | A fixpoint combinator, + -- because untyped lambda calculus fixpoint combinators won't typecheck. + | FixC_ + -- | A hole to ask the type inferencer about the context for debugging purposes. + | HoleC_ + deriving (Eq, Functor, Foldable, Traversable, Show) + +pattern CtrC :: Ctr -> CheckExpr +pattern CtrC c = Ctr c Unit + +pattern CtrFC :: Ctr -> CheckExprF r +pattern CtrFC c = CtrF c Unit + +pattern AppFC :: r -> r -> CheckExprF r +pattern AppFC ef ex = AppF ef (Identity ex) + +pattern CallCCC :: CheckExpr +pattern CallCCC = ExprX CallCCC_ + +pattern CallCCFC :: CheckExprF r +pattern CallCCFC = ExprXF CallCCC_ + +pattern FixC :: CheckExpr +pattern FixC = ExprX FixC_ + +pattern FixFC :: CheckExprF r +pattern FixFC = ExprXF FixC_ + +pattern HoleC :: CheckExpr +pattern HoleC = ExprX HoleC_ + +pattern HoleFC :: CheckExprF r +pattern HoleFC = ExprXF HoleC_ + +{-# COMPLETE Var, App, Abs, Let, CtrC, Case, ExprX #-} +{-# COMPLETE VarF, AppF, AbsF, LetF, CtrFC, CaseF, ExprXF #-} +{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrF, CaseF, ExprXF #-} +{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrFC, CaseF, ExprXF #-} +{-# COMPLETE Var, App, Abs, Let, Ctr, Case, CallCCC, FixC, HoleC #-} +{-# COMPLETE Var, App, Abs, Let, CtrC, Case, CallCCC, FixC, HoleC #-} +{-# COMPLETE VarF, AppF, AbsF, LetF, CtrFC, CaseF, CallCCFC, FixFC, HoleFC #-} +{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrF, CaseF, CallCCFC, FixFC, HoleFC #-} +{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrFC, CaseF, CallCCFC, FixFC, HoleFC #-} + +instance RecursivePhase Check where + projectAppArgs = Identity + projectLetArgs = projectDef + + embedAppArgs = runIdentity + embedLetArgs = embedDef + +instance Substitutable CheckExpr where + collectVars withVar withBinder = cata \case + VarF n -> withVar n + AbsF n e -> withBinder n e + LetF (Def n x) e -> x <> withBinder n e + CaseF pats -> foldMap (\(Pat _ ns e) -> foldr withBinder e ns) pats + e -> fold e + + rename = runRenamer $ \badNames -> cata \case + VarF n -> asks $ Var . HM.findWithDefault n n + AbsF n e -> uncurry Abs . first runIdentity <$> replaceNames badNames (Identity n) e + LetF (Def n x) e -> do + x' <- x + (Identity n', e') <- replaceNames badNames (Identity n) e + pure $ Let (n', x') e' + CaseF ps -> Case <$> forM ps \(Pat ctr ns e) -> + uncurry (Pat ctr) <$> replaceNames badNames ns e + e -> embed <$> sequenceA e + + unsafeSubstitute = runSubstituter $ para \case + VarF name -> asks $ HM.findWithDefault (Var name) name + AbsF name e -> Abs name <$> maySubstitute (Identity name) e + LetF (Def name (_, x)) e -> do + x' <- x + e' <- maySubstitute (Identity name) e + pure $ Let (name, x') e' + CaseF pats -> Case <$> for pats \(Pat ctr ns e) -> Pat ctr ns <$> maySubstitute ns e + e -> embed <$> traverse snd e + +-- | A monomorphic type. +data Type + -- | Type variable. + = TVar Text + -- | Type application. + | TApp Type Type + -- | The function type. + | TAbs + -- | The product type. + | TProd + -- | The sum type. + | TSum + -- | The unit type. + | TUnit + -- | The empty type. + | TVoid + -- | The type of natural numbers. + | TNat + -- | The type of lists. + | TList + -- | The type of characters. + | TChar + deriving (Eq, Show) + +makeBaseFunctor ''Type + +instance Substitutable Type where + collectVars withVar _ = cata \case + TVarF n -> withVar n + t -> fold t + + -- /All/ variables in a monomorphic type are free. + rename _ t = t + + -- No renaming step is necessary. + substitute substs = cata \case + TVarF n -> HM.findWithDefault (TVar n) n substs + e -> embed e + + unsafeSubstitute = substitute + +-- | A polymorphic type. +data Scheme + -- | Universally quantified type variables. + = TForall [Text] Type + deriving (Eq, Show) + +instance Substitutable Scheme where + collectVars withVar withBinder (TForall names t) = + foldMap withBinder names $ collectVars withVar withBinder t + + rename = runRenamer \badNames (TForall names t) -> + uncurry TForall <$> replaceNames badNames names (pure t) + + -- I took a shot at implementing this but found it to be quite difficult + -- because merging the foralls is tricky. + -- It's not undoable, but it wasn't worth my further time investment + -- seeing as this function isn't currently used anywhere. + unsafeSubstitute = error "Substitution for schemes not yet implemented" + +type Substitution = HashMap Text Type +type Context = HashMap Text Scheme +type Constraint = (Type, Type) + +class MonoSubstitutable t where + substituteMono :: Substitution -> t -> t + + substituteMono1 :: Text -> Type -> t -> t + substituteMono1 var val= substituteMono (HM.singleton var val) + +instance MonoSubstitutable Type where + substituteMono = substitute + +instance MonoSubstitutable Scheme where + substituteMono substs (TForall names t) = + TForall names $ substitute (foldMap HM.delete names substs) t + +instance MonoSubstitutable Constraint where + substituteMono substs = bimap (substituteMono substs) (substituteMono substs) + +instance MonoSubstitutable t => MonoSubstitutable [t] where + substituteMono = fmap . substituteMono + +instance MonoSubstitutable t => MonoSubstitutable (HashMap a t) where + substituteMono = fmap . substituteMono + +tapp :: [Type] -> Type +tapp [] = error "Empty type applications are not permitted" +tapp [t] = t +tapp ts = foldl1' TApp ts + +-- HACK +pattern TApp2 :: Type -> Type -> Type -> Type +pattern TApp2 tf tx ty = TApp (TApp tf tx) ty + +-- TODO: Improve these printers. +unparseType :: Type -> Text +unparseType (TVar name) = name +unparseType (TApp2 TAbs a b) = "(" <> unparseType a <> " -> " <> unparseType b <> ")" +unparseType (TApp2 TProd a b) = "(" <> unparseType a <> " * " <> unparseType b <> ")" +unparseType (TApp2 TSum a b) = "(" <> unparseType a <> " + " <> unparseType b <> ")" +unparseType (TApp TList a) = "[" <> unparseType a <> "]" +unparseType (TApp a b) = "(" <> unparseType a <> " " <> unparseType b <> ")" +unparseType TAbs = "(->)" +unparseType TProd = "(*)" +unparseType TSum = "(+)" +unparseType TUnit = "★" +unparseType TVoid = "⊥" +unparseType TNat = "Nat" +unparseType TList = "[]" +unparseType TChar = "Char" + +unparseScheme :: Scheme -> Text +unparseScheme (TForall [] t) = unparseType t +unparseScheme (TForall names t) = "∀" <> T.unwords names <> ". " <> unparseType t diff --git a/test/Spec.hs b/test/Spec.hs index 444409f..cc30b39 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -39,10 +39,10 @@ omega = App x x where x = Abs "x" (App (Var "x") (Var "x")) cc1 :: EvalExpr -cc1 = App CallCC (Abs "k" (App omega (App (Var "k") (Var "z")))) +cc1 = App CallCCE (Abs "k" (App omega (App (Var "k") (Var "z")))) cc2 :: EvalExpr -cc2 = App (Var "y") (App CallCC (Abs "k" (App (Var "z") (App (Var "k") (Var "x"))))) +cc2 = App (Var "y") (App CallCCE (Abs "k" (App (Var "z") (App (Var "k") (Var "x"))))) main :: IO () main = defaultMain $