Added support for Hindley-Milner type inference!

I also generalized substitution to *all* expressions (and types),
which itself involved a rewrite of the substitution function I already had.
(In hindsight, I wish I'd done that in a separate commit.)

The last four days have all been working up to this,
so I'm glad I've finally managed to integrate it into this project!
(I wrote the algorithm 4 days ago, but the infrastructure
just wasn't there to add it here.)
master
James T. Martin 2021-03-18 00:00:43 -07:00
parent 586be18c80
commit 9e0754daf6
Signed by: james
GPG Key ID: 4B7F3DA9351E577C
15 changed files with 806 additions and 254 deletions

View File

@ -1,11 +1,14 @@
# Lambda Calculus # Lambda Calculus
This is a simple programming language derived from lambda calculus. This is a simple programming language derived from lambda calculus,
using the Hindley-Milner type system, plus some builtin types, `fix`, and `callcc`
## Usage ## Usage
Run the program using `stack run` (or run the tests with `stack test`). Run the program using `stack run` (or run the tests with `stack test`).
Type in your expression at the prompt: `>> `. Type in your expression at the prompt: `>> `. This will happen:
The expression will be evaluated to normal form using the call-by-value evaluation strategy and then printed. * the type for the expression will be inferred and then printed,
* then, regardless of whether typechecking succeeded, expression will be evaluated to normal form using the call-by-value evaluation strategy and then printed.
Exit the prompt with `Ctrl-c` (or equivalent). Exit the prompt with `Ctrl-c` (or equivalent).
## Syntax ## Syntax
@ -16,23 +19,38 @@ The parser's error messages currently are virtually useless, so be very careful
* Lambda abstraction: `\x y z. E` or `λx y z. E` * Lambda abstraction: `\x y z. E` or `λx y z. E`
* Let expressions: `let x = E; y = F in G` * Let expressions: `let x = E; y = F in G`
* Parenthetical expressions: `(E)` * Parenthetical expressions: `(E)`
* Constructors: `()`, `(x, y)` (or `(,) x y`), `Left x`, `Right y`, `Z`, `S`, `[]`, `(x : xs)` (or `(:) x xs`), `Char n`. * Constructors: `()`, `(x, y)` (or `(,) x y`), `Left x`, `Right y`, `Z`, `S`, `[]`, `(x :: xs)` (or `(:) x xs`), `Char n`.
* The parentheses around the cons constructor are not optional. * The parentheses around the cons constructor are not optional.
* `Char` takes a natural number and turns it into a character. * `Char` takes a natural number and turns it into a character.
* Pattern matchers: `{ Left x -> e ; Right y -> f }` * Pattern matchers: `case { Left a -> e ; Right y -> f }`
* Pattern matchers can be applied like functions, e.g. `{ Z -> x, S -> y } 10` reduces to `y`. * Pattern matchers can be applied like functions, e.g. `{ Z -> x, S -> y } 10` reduces to `y`.
* Patterns must use the regular form of the constructor, e.g. `(x : xs)` and not `((:) x xs)`. * Patterns must use the regular form of the constructor, e.g. `(x :: xs)` and not `((::) x xs)`.
* There are no nested patterns or default patterns. * There are no nested patterns or default patterns.
* Incomplete pattern matches will crash the interpreter. * Incomplete pattern matches will crash the interpreter.
* Literals: `1234`, `[e, f, g, h]`, `'a`, `"abc"` * Literals: `1234`, `[e, f, g, h]`, `'a`, `"abc"`
* Strings are represented as lists of characters. * Strings are represented as lists of characters.
* Type annotations: there are no type annotations; types are inferred only.
## Call/CC ## Types
This interpreter has preliminary support for Types are checked/inferred using the Hindley-Milner type inference algorithm.
[the call-with-current-continuation control flow operator](https://en.wikipedia.org/wiki/Call-with-current-continuation).
However, it has not been thoroughly tested.
To use it, simply apply the variable `callcc` like you would a function, e.g. `(callcc (\k. ...))`. * Functions: `a -> b` (constructed by `\x. e`)
* Products: `a * b` (constructed by `(x, y)`)
* Unit: `★` (constructed by `()`)
* Sums: `a + b` (constructed by `Left x` or `Right y`)
* Bottom: `⊥` (currently useless because incomplete patterns are allowed)
* The natural numbers: `Nat` (constructed by `Z` and `S`)
* Lists: `List a` (constructed by `[]` and `(x :: xs)`)
* Characters: `Char` (constructed by `Char`, which takes a `Nat`)
* Universal quantification (forall): `∀a b. t`
## Builtins
Builtins are variables that correspond with a built-in language feature
that cannot be replicated by user-written code.
They still are just variables though; they do not receive special syntactic treatment.
* `fix : ∀a. ((a -> a) -> a)`: an alias for the strict fixpoint combinator that the typechecker can typecheck.
* `callcc : ∀a b. (((a -> b) -> a) -> a)`: [the call-with-current-continuation control flow operator](https://en.wikipedia.org/wiki/Call-with-current-continuation).
Continuations are printed as `λ!. ... ! ...`, like a lambda abstraction Continuations are printed as `λ!. ... ! ...`, like a lambda abstraction
with an argument named `!` which is used exactly once; with an argument named `!` which is used exactly once;
@ -41,56 +59,61 @@ because they perform the side effect of modifying the current continuation,
and this is *not* valid syntax you can input into the REPL. and this is *not* valid syntax you can input into the REPL.
## Example code ## Example code
The fixpoint function:
```
(\x. x x) \fix f x. f (fix fix f) x
```
Create a list by iterating `f` `n` times: Create a list by iterating `f` `n` times:
``` ```
fix \iterate f x. { Z -> x ; S n -> iterate f (f x) n } fix \iterate f x. { Z -> [] ; S n -> (x :: iterate f (f x) n) }
: ∀e. ((e -> e) -> (e -> (Nat -> [e])))
``` ```
Create a list whose first element is `n - 1`, counting down to a last element of `0`: Use the iterate function to count to 10:
``` ```
\n. { (n, x) -> x } (iterate { (n, x) -> (S n, (n : x)) } (0, []) n) >> let iterate = fix \iterate f x. { Z -> [] ; S n -> (x :: iterate f (f x) n) }; countTo = iterate S 1 in countTo 10
``` : [Nat]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
Putting it all together to count down from 10:
```
>> let fix = (\x. x x) \fix f x. f (fix fix f) x; iterate = fix \iterate f x. { Z -> x ; S n -> iterate f (f x) n }; countDownFrom = \n. { (n, x) -> x } (iterate { (n, x) -> (S n, (n : x)) } (0, []) n) in countDownFrom 10
[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
``` ```
Append two lists together: Append two lists together:
``` ```
fix \append xs ys. { [] -> ys ; (x : xs) -> (x : append xs ys) } xs fix \append xs ys. { [] -> ys ; (x :: xs) -> (x :: append xs ys) } xs
: ∀j. ([j] -> ([j] -> [j]))
``` ```
Reverse a list: Reverse a list:
``` ```
fix \reverse. { [] -> [] ; (x : xs) -> append (reverse xs) [x] } fix \reverse. { [] -> [] ; (x :: xs) -> append (reverse xs) [x] }
: ∀c1. ([c1] -> [c1])
``` ```
Putting them together so we can reverse `"reverse"`: Putting them together so we can reverse `"reverse"`:
``` ```
>> let fix = (\x. x x) \fix f x. f (fix fix f) x; append = fix \append xs ys. { [] -> ys ; (x : xs) -> (x : append xs ys) } xs; reverse = fix \reverse. { [] -> [] ; (x : xs) -> append (reverse xs) [x] } in reverse "reverse" >> let append = fix \append xs ys. { [] -> ys ; (x :: xs) -> (x :: append xs ys) } xs; reverse = fix \reverse. { [] -> [] ; (x :: xs) -> append (reverse xs) [x] } in reverse "reverse"
: [Char]
"esrever" "esrever"
``` ```
Calculating `3 + 2` with the help of Church-encoded numerals: Calculating `3 + 2` with the help of Church-encoded numerals:
``` ```
>> let Sf = \n f x. f (n f x); plus = \x. x Sf in plus (\f x. f (f (f x))) (\f x. f (f x)) S Z >> let Sf = \n f x. f (n f x); plus = \x. x Sf in plus (\f x. f (f (f x))) (\f x. f (f x)) S Z
: Nat
5 5
``` ```
This expression would loop forever, but `callcc` saves the day! This expression would loop forever, but `callcc` saves the day!
``` ```
>> y (callcc \k. (\x. (\x. x x) (\x. x x)) (k z)) >> S (callcc \k. (fix \x. x) (k Z))
y z : Nat
1
``` ```
A few other weird expressions: And if it wasn't clear, this is what the `Char` constructor does:
```
>> { Char c -> Char (S c) } 'a
: Char
'b
```
Here are a few expressions which don't typecheck but are handy for debugging the evaluator:
``` ```
>> let D = \x. x x; F = \f. f (f y) in D (F \x. x) >> let D = \x. x x; F = \f. f (f y) in D (F \x. x)
y y y y
@ -98,6 +121,4 @@ y y
y y
>> (\x y z. x y) y >> (\x y z. x y) y
λy' z. y y' λy' z. y y'
>> { Char c -> Char (S c) } 'a
'b
``` ```

View File

@ -15,8 +15,10 @@ prompt text = do
getLine getLine
main :: IO () main :: IO ()
main = forever $ parseEval <$> prompt ">> " >>= \case main = forever $ parseCheck <$> prompt ">> " >>= \case
Left parseError -> putStrLn $ "Parse error: " <> pack (show parseError) Left parseError -> putStrLn $ "Parse error: " <> pack (show parseError)
-- TODO: Support choosing which version to use at runtime. -- TODO: Support choosing which version to use at runtime.
Right expr -> putStrLn $ unparseEval $ eval expr Right expr -> do
--Right expr -> mapM_ (putStrLn . unparseEval) $ snd $ traceEval expr either putStrLn (putStrLn . (": " <>) . unparseScheme) $ infer expr
putStrLn $ unparseEval $ eval $ check2eval expr
--mapM_ (putStrLn . unparseEval) $ snd $ traceEval $ check2eval expr

View File

@ -20,9 +20,11 @@ default-extensions:
- FlexibleContexts - FlexibleContexts
- FlexibleInstances - FlexibleInstances
- ImportQualifiedPost - ImportQualifiedPost
- InstanceSigs
- LambdaCase - LambdaCase
- OverloadedStrings - OverloadedStrings
- PatternSynonyms - PatternSynonyms
- ScopedTypeVariables
- StandaloneDeriving - StandaloneDeriving
- ViewPatterns - ViewPatterns
# Required for use of the 'trees that grow' pattern # Required for use of the 'trees that grow' pattern
@ -32,7 +34,6 @@ default-extensions:
- DeriveFoldable - DeriveFoldable
- DeriveFunctor - DeriveFunctor
- DeriveTraversable - DeriveTraversable
- TemplateHaskell
dependencies: dependencies:
- base >= 4.14 && < 5 - base >= 4.14 && < 5

View File

@ -2,15 +2,23 @@ module LambdaCalculus
( module LambdaCalculus.Evaluator ( module LambdaCalculus.Evaluator
, module LambdaCalculus.Expression , module LambdaCalculus.Expression
, module LambdaCalculus.Syntax , module LambdaCalculus.Syntax
, parseEval, unparseEval , module LambdaCalculus.Types
, parseCheck, parseEval, unparseCheck, unparseEval
) where ) where
import LambdaCalculus.Evaluator import LambdaCalculus.Evaluator
import LambdaCalculus.Expression import LambdaCalculus.Expression
import LambdaCalculus.Syntax import LambdaCalculus.Syntax
import LambdaCalculus.Types
parseCheck :: Text -> Either ParseError CheckExpr
parseCheck = fmap ast2check . parseAST
parseEval :: Text -> Either ParseError EvalExpr parseEval :: Text -> Either ParseError EvalExpr
parseEval = fmap ast2eval . parseAST parseEval = fmap ast2eval . parseAST
unparseCheck :: CheckExpr -> Text
unparseCheck = unparseAST . simplify . check2ast
unparseEval :: EvalExpr -> Text unparseEval :: EvalExpr -> Text
unparseEval = unparseAST . simplify . eval2ast unparseEval = unparseAST . simplify . eval2ast

View File

@ -2,130 +2,29 @@ module LambdaCalculus.Evaluator
( Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), VoidF, UnitF (..), Text ( Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), VoidF, UnitF (..), Text
, Eval, EvalExpr, EvalX, EvalXF (..) , Eval, EvalExpr, EvalX, EvalXF (..)
, pattern AppFE, pattern CtrE, pattern CtrFE , pattern AppFE, pattern CtrE, pattern CtrFE
, pattern Cont, pattern ContF, pattern CallCC, pattern CallCCF , pattern ContE, pattern ContFE, pattern CallCCE, pattern CallCCFE
, eval, traceEval, substitute, alphaConvert , eval, traceEval
) where ) where
import LambdaCalculus.Evaluator.Base import LambdaCalculus.Evaluator.Base
import LambdaCalculus.Evaluator.Continuation import LambdaCalculus.Evaluator.Continuation
import Control.Monad (forM)
import Control.Monad.Except (MonadError, ExceptT, throwError, runExceptT) import Control.Monad.Except (MonadError, ExceptT, throwError, runExceptT)
import Control.Monad.State (MonadState, State, evalState, modify', state, put, gets) import Control.Monad.State (MonadState, evalState, modify', state, put, gets)
import Control.Monad.Writer (runWriterT, tell) import Control.Monad.Writer (runWriterT, tell)
import Data.Foldable (fold) import Data.HashMap.Strict qualified as HM
import Data.Functor.Foldable (cata, para, embed)
import Data.HashSet (HashSet)
import Data.HashSet qualified as HS
import Data.Stream qualified as S
import Data.Text qualified as T
import Data.Void (Void, absurd) import Data.Void (Void, absurd)
-- | Free variables are variables which are present in an expression but not bound by any abstraction.
freeVars :: EvalExpr -> HashSet Text
freeVars = cata \case
VarF n -> HS.singleton n
AbsF n e -> HS.delete n e
ContF e -> HS.delete "!" e
CaseF ps -> foldMap (\(Pat _ ns e) -> HS.difference e (HS.fromList ns)) ps
e -> fold e
-- | Bound variables are variables which are bound by any form of abstraction in an expression.
boundVars :: EvalExpr -> HashSet Text
boundVars = cata \case
AbsF n e -> HS.insert n e
ContF e -> HS.insert "!" e
CaseF ps -> foldMap (\(Pat _ ns e) -> HS.union (HS.fromList ns) e) ps
e -> fold e
-- | Vars that occur anywhere in an experession, bound or free.
usedVars :: EvalExpr -> HashSet Text
usedVars x = HS.union (freeVars x) (boundVars x)
-- | Substitution is the process of replacing all free occurrences of a variable in one expression with another expression.
substitute :: Text -> EvalExpr -> EvalExpr -> EvalExpr
substitute var val = unsafeSubstitute var val . alphaConvert (freeVars val)
-- | Substitution is only safe if the bound variables in the body
-- are disjoint from the free variables in the argument;
-- this function makes an expression body safe for substitution
-- by replacing the bound variables in the body
-- with completely new variables which do not occur in either expression
-- (without changing any *free* variables in the body, of course).
alphaConvert :: HashSet Text -> EvalExpr -> EvalExpr
alphaConvert ctx e_ = evalState (alphaConverter e_) $ HS.union ctx (usedVars e_)
where
alphaConverter :: EvalExpr -> State (HashSet Text) EvalExpr
alphaConverter = cata \case
e
| AbsF n e' <- e, n `HS.member` ctx -> do
n' <- fresh n
e'' <- e'
pure $ Abs n' $ replace n n' e''
-- | TODO: Only replace the names that *have* to be replaced.
| CaseF ps <- e, any (any (`HS.member` ctx) . patNames) ps ->
Case <$> forM ps \(Pat ctr ns e') -> do
ns' <- mapM fresh ns
e'' <- e'
pure $ Pat ctr ns' $ foldr (uncurry replace) e'' (zip ns ns')
| otherwise -> embed <$> sequenceA e
-- | Create a new name which is not used anywhere else.
fresh :: Text -> State (HashSet Text) Text
fresh n = state \ctx' ->
let n' = S.head $ S.filter (not . (`HS.member` ctx')) names
in (n', HS.insert n' ctx')
where names = S.iterate (`T.snoc` '\'') n
-- | Replace a name with an entirely new name in all contexts.
-- This will only give correct results if
-- the new name does not occur anywhere in the expression.
replace :: Text -> Text -> EvalExpr -> EvalExpr
replace name name' = cata \case
e
| VarF name2 <- e, name == name2 -> Var name'
| AbsF name2 e' <- e, name == name2 -> Abs name' e'
| CaseF ps <- e -> Case $ flip map ps \(Pat ctr ns e') -> Pat ctr (replace' ns) e'
| otherwise -> embed e
where
replace' = map \case
n
| n == name -> name'
| otherwise -> n
-- | Substitution which does *not* avoid variable capture;
-- it only gives the correct result if the bound variables in the body
-- are disjoint from the free variables in the argument.
unsafeSubstitute :: Text -> EvalExpr -> EvalExpr -> EvalExpr
unsafeSubstitute var val = para \case
e'
| VarF var2 <- e', var == var2 -> val
| AbsF var2 _ <- e', var == var2 -> unmodified e'
| ContF _ <- e', var == "!" -> unmodified e'
| CaseF ps <- e' -> Case $ flip map ps \(Pat ctr ns (unmod, sub)) ->
Pat ctr ns if var `elem` ns then unmod else sub
| otherwise -> substituted e'
where
substituted, unmodified :: EvalExprF (EvalExpr, EvalExpr) -> EvalExpr
substituted = embed . fmap snd
unmodified = embed . fmap fst
isReducible :: EvalExpr -> Bool isReducible :: EvalExpr -> Bool
isReducible = snd . cata \case -- Applications of function type constructors
AppFE ctr args -> active ctr [args] isReducible (App (Abs _ _) _) = True
AbsF _ _ -> passive isReducible (App (ContE _) _) = True
ContF _ -> passive isReducible (App CallCCE _) = True
CaseF _ -> passive -- Pattern matching of data
CallCCF -> passive isReducible (App (Case _) ex) = isData ex || isReducible ex
CtrFE _ -> constant -- Reducible subexpressions
VarF _ -> constant isReducible (App ef ex) = isReducible ef || isReducible ex
where isReducible _ = False
-- | Constants are irreducible in any context.
constant = (False, False)
-- | Passive expressions are reducible only if an active expression is applied to them.
passive = (True, False)
-- | Active expressions are reducible if they are applied to a constructor or their arguments are reducible.
active ctr args = (False, fst ctr || snd ctr || any snd args)
lookupPat :: Ctr -> [Pat phase] -> Pat phase lookupPat :: Ctr -> [Pat phase] -> Pat phase
lookupPat ctr = foldr lookupCtr' (error "Constructor not found") lookupPat ctr = foldr lookupCtr' (error "Constructor not found")
@ -182,21 +81,22 @@ evaluatorStep = \case
| otherwise -> case ef of | otherwise -> case ef of
-- perform beta reduction if possible... -- perform beta reduction if possible...
Abs name body -> Abs name body ->
pure $ substitute name ex body pure $ substitute1 name ex body
Case pats Case pats ->
| isData ex -> do if isData ex
let (ctr, xs) = toData ex then do
let Pat _ ns e = lookupPat ctr pats let (ctr, xs) = toData ex
pure $ foldr (uncurry substitute) e (zip ns xs) let Pat _ ns e = lookupPat ctr pats
| otherwise -> ret unmodified pure $ substitute (HM.fromList $ zip ns xs) e
else ret unmodified
-- perform continuation calls if possible... -- perform continuation calls if possible...
Cont body -> do ContE body -> do
put [] put []
pure $ substitute "!" ex body pure $ substitute1 "!" ex body
-- capture the current continuation if requested... -- capture the current continuation if requested...
CallCC -> do CallCCE -> do
k <- gets $ continue (Var "!") k <- gets $ continue (Var "!")
pure $ App ex (Cont k) pure $ App ex (ContE k)
-- otherwise the value is irreducible and we can continue evaluation. -- otherwise the value is irreducible and we can continue evaluation.
_ -> ret unmodified _ -> ret unmodified
-- Neither abstractions, constructors nor variables are reducible. -- Neither abstractions, constructors nor variables are reducible.

View File

@ -1,14 +1,22 @@
module LambdaCalculus.Evaluator.Base module LambdaCalculus.Evaluator.Base
( Identity (..) ( Identity (..)
, Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), VoidF, UnitF (..), Text , Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), VoidF, UnitF (..), Text
, substitute, substitute1, rename, rename1, free, bound, used
, Eval, EvalExpr, EvalExprF, EvalX, EvalXF (..) , Eval, EvalExpr, EvalExprF, EvalX, EvalXF (..)
, pattern AppFE, pattern CtrE, pattern CtrFE , pattern AppFE, pattern CtrE, pattern CtrFE
, pattern Cont, pattern ContF, pattern CallCC, pattern CallCCF , pattern ContE, pattern ContFE, pattern CallCCE, pattern CallCCFE
) where ) where
import LambdaCalculus.Expression.Base import LambdaCalculus.Expression.Base
import Control.Monad (forM)
import Control.Monad.Reader (asks)
import Data.Bifunctor (first)
import Data.Foldable (fold)
import Data.Functor.Identity (Identity (..)) import Data.Functor.Identity (Identity (..))
import Data.Functor.Foldable (embed, cata, para)
import Data.HashMap.Strict qualified as HM
import Data.Traversable (for)
data Eval data Eval
type EvalExpr = Expr Eval type EvalExpr = Expr Eval
@ -33,41 +41,62 @@ data EvalXF r
-- --
-- Continuations do not have any corresponding surface-level syntax, -- Continuations do not have any corresponding surface-level syntax,
-- but may be printed like a lambda with the illegal variable `!`. -- but may be printed like a lambda with the illegal variable `!`.
= Cont_ !r = ContE_ !r
-- | Call-with-current-continuation, an evaluator built-in function. -- | Call-with-current-continuation, an evaluator built-in function.
| CallCC_ | CallCCE_
deriving (Eq, Functor, Foldable, Traversable, Show) deriving (Eq, Functor, Foldable, Traversable, Show)
instance RecursivePhase Eval where
projectAppArgs = Identity
embedAppArgs = runIdentity
pattern CtrE :: Ctr -> EvalExpr pattern CtrE :: Ctr -> EvalExpr
pattern CtrE c = Ctr c Unit pattern CtrE c = Ctr c Unit
pattern CtrFE :: Ctr -> EvalExprF r pattern CtrFE :: Ctr -> EvalExprF r
pattern CtrFE c = CtrF c Unit pattern CtrFE c = CtrF c Unit
pattern Cont :: EvalExpr -> EvalExpr pattern ContE :: EvalExpr -> EvalExpr
pattern Cont e = ExprX (Cont_ e) pattern ContE e = ExprX (ContE_ e)
pattern CallCC :: EvalExpr pattern CallCCE :: EvalExpr
pattern CallCC = ExprX CallCC_ pattern CallCCE = ExprX CallCCE_
pattern ContF :: r -> EvalExprF r pattern ContFE :: r -> EvalExprF r
pattern ContF e = ExprXF (Cont_ e) pattern ContFE e = ExprXF (ContE_ e)
pattern CallCCF :: EvalExprF r pattern CallCCFE :: EvalExprF r
pattern CallCCF = ExprXF CallCC_ pattern CallCCFE = ExprXF CallCCE_
pattern AppFE :: r -> r -> EvalExprF r pattern AppFE :: r -> r -> EvalExprF r
pattern AppFE ef ex = AppF ef (Identity ex) pattern AppFE ef ex = AppF ef (Identity ex)
{-# COMPLETE Var, App, Abs, Let, Ctr, Case, Cont, CallCC #-} {-# COMPLETE Var, App, Abs, Let, Ctr, Case, ContE, CallCCE #-}
{-# COMPLETE VarF, AppF, AbsF, LetF, CtrF, CaseF, ContF, CallCCF #-} {-# COMPLETE VarF, AppF, AbsF, LetF, CtrF, CaseF, ContFE, CallCCFE #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ExprXF #-} {-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ExprXF #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ContF, CallCCF #-} {-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ContFE, CallCCFE #-}
{-# COMPLETE Var, App, Abs, Let, CtrE, Case, Cont, CallCC #-} {-# COMPLETE Var, App, Abs, Let, CtrE, Case, ContE, CallCCE #-}
{-# COMPLETE VarF, AppF, AbsF, LetF, CtrFE, CaseF, ContF, CallCCF #-} {-# COMPLETE VarF, AppF, AbsF, LetF, CtrFE, CaseF, ContFE, CallCCFE #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ExprXF #-} {-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ExprXF #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ContF, CallCCF #-} {-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ContFE, CallCCFE #-}
instance RecursivePhase Eval where
projectAppArgs = Identity
embedAppArgs = runIdentity
instance Substitutable EvalExpr where
collectVars withVar withBinder = cata \case
VarF n -> withVar n
AbsF n e -> withBinder n e
CaseF pats -> foldMap (\(Pat _ ns e) -> foldr withBinder e ns) pats
e -> fold e
rename = runRenamer $ \badNames -> cata \case
VarF n -> asks $ Var . HM.findWithDefault n n
AbsF n e -> uncurry Abs . first runIdentity <$> replaceNames badNames (Identity n) e
ContFE e -> ContE <$> e
CaseF ps -> Case <$> forM ps \(Pat ctr ns e) -> uncurry (Pat ctr) <$> replaceNames badNames ns e
e -> embed <$> sequenceA e
unsafeSubstitute = runSubstituter $ para \case
VarF name -> asks $ HM.findWithDefault (Var name) name
AbsF name e -> Abs name <$> maySubstitute (Identity name) e
ContFE e -> ContE <$> maySubstitute (Identity "!") e
CaseF pats -> Case <$> for pats \(Pat ctr ns e) -> Pat ctr ns <$> maySubstitute ns e
e -> embed <$> traverse snd e

View File

@ -1,6 +1,6 @@
module LambdaCalculus.Evaluator.Continuation module LambdaCalculus.Evaluator.Continuation
( Continuation, continue, continue1 ( Continuation, continue, continue1
, ContinuationCrumb (ApplyTo, AppliedTo, AbstractedOver) , ContinuationCrumb (..)
) where ) where
import LambdaCalculus.Evaluator.Base import LambdaCalculus.Evaluator.Base

View File

@ -1,46 +1,94 @@
module LambdaCalculus.Expression module LambdaCalculus.Expression
( Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), DefF (..), VoidF, UnitF (..), Text ( Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), DefF (..), VoidF, UnitF (..), Text
, substitute, substitute1, rename, free, bound, used
, Eval, EvalExpr, EvalX, EvalXF (..), Identity (..) , Eval, EvalExpr, EvalX, EvalXF (..), Identity (..)
, pattern AppFE, pattern CtrE, pattern CtrFE, , pattern AppFE, pattern CtrE, pattern CtrFE,
pattern Cont, pattern ContF, pattern CallCC, pattern CallCCF pattern ContE, pattern ContFE, pattern CallCCE, pattern CallCCFE
, Parse, AST, ASTF, ASTX, ASTXF (..), NonEmptyDefFs (..), NonEmpty (..), simplify , Parse, AST, ASTF, ASTX, ASTXF (..), NonEmptyDefFs (..), NonEmpty (..), simplify
, pattern LetFP, pattern PNat, pattern PNatF, pattern PList, pattern PListF , pattern LetFP, pattern PNat, pattern PNatF, pattern PList, pattern PListF
, pattern PChar, pattern PCharF, pattern PStr, pattern PStrF , pattern PChar, pattern PCharF, pattern PStr, pattern PStrF, pattern HoleP, pattern HoleFP
, ast2eval, eval2ast , Check, CheckExpr, CheckExprF, CheckX, CheckXF (..)
, pattern AppFC, pattern CtrC, pattern CtrFC, pattern CallCCC, pattern CallCCFC
, pattern FixC, pattern FixFC, pattern HoleC, pattern HoleFC
, Type (..), TypeF (..), Scheme (..), tapp
, ast2check, ast2eval, check2eval, check2ast, eval2ast
) where ) where
import LambdaCalculus.Evaluator.Base import LambdaCalculus.Evaluator.Base
import LambdaCalculus.Evaluator (alphaConvert, substitute)
import LambdaCalculus.Syntax.Base import LambdaCalculus.Syntax.Base
import LambdaCalculus.Types.Base
import Data.Functor.Foldable (cata, hoist) import Data.Functor.Foldable (cata, hoist)
import Data.HashSet qualified as HS import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM
import Data.List (foldl') import Data.List (foldl')
import Data.List.NonEmpty (toList) import Data.List.NonEmpty (toList)
import Data.Text (unpack) import Data.Text (unpack)
-- | Convert from an abstract syntax tree to an evaluator expression. builtins :: HashMap Text CheckExpr
ast2eval :: AST -> EvalExpr builtins = HM.fromList [("callcc", CallCCC), ("fix", FixC)]
ast2eval = substitute "callcc" CallCC . cata \case
-- | Convert from an abstract syntax tree to a typechecker expression.
ast2check :: AST -> CheckExpr
ast2check = substitute builtins . cata \case
VarF name -> Var name VarF name -> Var name
AppF ef exs -> foldl' App ef $ toList exs AppF ef exs -> foldl' App ef $ toList exs
AbsF ns e -> foldr Abs e $ toList ns AbsF ns e -> foldr Abs e $ toList ns
LetF ds e -> LetF ds e ->
let letExpr name val body' = App (Abs name body') val let letExpr name val body' = App (Abs name body') val
in foldr (uncurry letExpr) e $ getNonEmptyDefFs ds in foldr (uncurry letExpr) e $ getNonEmptyDefFs ds
CtrF ctr es -> foldl' App (CtrE ctr) es CtrF ctr es -> foldl' App (CtrC ctr) es
CaseF ps -> Case ps CaseF ps -> Case ps
PNatF n -> int2ast n PNatF n -> int2ast n
PListF es -> mkList es PListF es -> mkList es
PStrF s -> mkList $ map (App (CtrE CChar) . int2ast . fromEnum) $ unpack s PStrF s -> mkList $ map (App (CtrC CChar) . int2ast . fromEnum) $ unpack s
PCharF c -> App (CtrE CChar) (int2ast $ fromEnum c) PCharF c -> App (CtrC CChar) (int2ast $ fromEnum c)
HoleFP -> HoleC
where where
int2ast :: Int -> EvalExpr int2ast :: Int -> CheckExpr
int2ast 0 = CtrE CZero int2ast 0 = CtrC CZero
int2ast n = App (CtrE CSucc) (int2ast (n - 1)) int2ast n = App (CtrC CSucc) (int2ast (n - 1))
mkList :: [EvalExpr] -> EvalExpr mkList :: [CheckExpr] -> CheckExpr
mkList = foldr (App . App (CtrE CCons)) (CtrE CNil) mkList = foldr (App . App (CtrC CCons)) (CtrC CNil)
-- | Convert from a typechecker expression to an evaluator expression.
check2eval :: CheckExpr -> EvalExpr
check2eval = cata \case
VarF name -> Var name
AppFC ef ex -> App ef ex
AbsF n e -> Abs n e
LetF (Def nx ex) e -> App (Abs nx e) ex
CtrFC ctr -> CtrE ctr
CaseF ps -> Case ps
CallCCFC -> CallCCE
FixFC -> z
HoleFC -> omega
where
z, omega :: EvalExpr
z = App omega $ Abs "fix" $ Abs "f" $ Abs "x" $
App (App (Var "f") (App (App (Var "fix") (Var "fix")) (Var "f"))) (Var "x")
omega = Abs "x" (App (Var "x") (Var "x"))
-- | Convert from an abstract syntax tree to an evaluator expression.
ast2eval :: AST -> EvalExpr
ast2eval = check2eval . ast2check
-- | Convert from a typechecker expression to an abstract syntax tree.
check2ast :: CheckExpr -> AST
check2ast = hoist go . rename (HM.keysSet builtins)
where
go :: CheckExprF r -> ASTF r
go = \case
VarF name -> VarF name
AppFC ef ex -> AppF ef (ex :| [])
AbsF n e -> AbsF (n :| []) e
LetF (Def nx ex) e -> LetFP ((nx, ex) :| []) e
CtrFC ctr -> CtrF ctr []
CaseF ps -> CaseF ps
CallCCFC-> VarF "callcc"
FixFC -> VarF "fix"
HoleFC -> HoleFP
-- | Convert from an evaluator expression to an abstract syntax tree. -- | Convert from an evaluator expression to an abstract syntax tree.
eval2ast :: EvalExpr -> AST eval2ast :: EvalExpr -> AST
@ -48,14 +96,14 @@ eval2ast :: EvalExpr -> AST
-- all instances of `callcc` must be bound; -- all instances of `callcc` must be bound;
-- therefore, we are free to alpha convert them, -- therefore, we are free to alpha convert them,
-- freeing the name `callcc` for us to use for the built-in again. -- freeing the name `callcc` for us to use for the built-in again.
eval2ast = hoist go . alphaConvert (HS.singleton "callcc") eval2ast = hoist go . rename (HM.keysSet builtins)
where where
go :: EvalExprF r -> ASTF r go :: EvalExprF r -> ASTF r
go = \case go = \case
VarF name -> VarF name VarF name -> VarF name
CallCCF -> VarF "callcc"
AppFE ef ex -> AppF ef (ex :| []) AppFE ef ex -> AppF ef (ex :| [])
AbsF n e -> AbsF (n :| []) e AbsF n e -> AbsF (n :| []) e
CtrFE ctr -> CtrF ctr [] CtrFE ctr -> CtrF ctr []
CaseF ps -> CaseF ps CaseF ps -> CaseF ps
ContF e -> AbsF ("!" :| []) e CallCCFE -> VarF "callcc"
ContFE e -> AbsF ("!" :| []) e

View File

@ -1,3 +1,4 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UndecidableInstances #-}
module LambdaCalculus.Expression.Base module LambdaCalculus.Expression.Base
( Text, VoidF, UnitF (..), absurd' ( Text, VoidF, UnitF (..), absurd'
@ -5,11 +6,24 @@ module LambdaCalculus.Expression.Base
, ExprF (..), PatF (..), DefF (..), AppArgsF, LetArgsF, CtrArgsF, XExprF , ExprF (..), PatF (..), DefF (..), AppArgsF, LetArgsF, CtrArgsF, XExprF
, RecursivePhase, projectAppArgs, projectLetArgs, projectCtrArgs, projectXExpr, projectDef , RecursivePhase, projectAppArgs, projectLetArgs, projectCtrArgs, projectXExpr, projectDef
, embedAppArgs, embedLetArgs, embedCtrArgs, embedXExpr, embedDef , embedAppArgs, embedLetArgs, embedCtrArgs, embedXExpr, embedDef
, Substitutable, free, bound, used, collectVars, rename, rename1
, substitute, substitute1, unsafeSubstitute, unsafeSubstitute1
, runRenamer, freshVar, replaceNames, runSubstituter, maySubstitute
) where ) where
import Control.Monad.Reader (MonadReader, Reader, runReader, asks, local)
import Control.Monad.State (MonadState, StateT, evalStateT, state)
import Control.Monad.Zip (MonadZip, mzipWith)
import Data.Foldable (fold)
import Data.Functor.Foldable (Base, Recursive, Corecursive, project, embed) import Data.Functor.Foldable (Base, Recursive, Corecursive, project, embed)
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM
import Data.HashSet (HashSet)
import Data.HashSet qualified as HS
import Data.Kind (Type) import Data.Kind (Type)
import Data.Stream qualified as S
import Data.Text (Text) import Data.Text (Text)
import Data.Text qualified as T
data Expr phase data Expr phase
-- | A variable: `x`. -- | A variable: `x`.
@ -100,8 +114,8 @@ type family LetArgsF phase :: Type -> Type
type family CtrArgsF phase :: Type -> Type type family CtrArgsF phase :: Type -> Type
type family XExprF phase :: Type -> Type type family XExprF phase :: Type -> Type
data DefF r = DefF !Text !r data DefF r = Def !Text !r
deriving (Eq, Functor, Show) deriving (Eq, Functor, Foldable, Traversable, Show)
-- | A contractible data type with one extra type parameter. -- | A contractible data type with one extra type parameter.
data UnitF a = Unit data UnitF a = Unit
@ -199,10 +213,10 @@ class Functor (ExprF phase) => RecursivePhase phase where
embedXExpr = id embedXExpr = id
projectDef :: Def phase -> DefF (Expr phase) projectDef :: Def phase -> DefF (Expr phase)
projectDef = uncurry DefF projectDef = uncurry Def
embedDef :: DefF (Expr phase) -> Def phase embedDef :: DefF (Expr phase) -> Def phase
embedDef (DefF n e) = (n, e) embedDef (Def n e) = (n, e)
instance RecursivePhase phase => Recursive (Expr phase) where instance RecursivePhase phase => Recursive (Expr phase) where
project = \case project = \case
@ -227,3 +241,117 @@ instance RecursivePhase phase => Corecursive (Expr phase) where
--- ---
--- End base functor boilerplate. --- End base functor boilerplate.
--- ---
class Substitutable e where
-- | Fold over the variables in the expression with a monoid,
-- given what to do with variable usage sites and binding sites respectively.
collectVars :: Monoid m => (Text -> m) -> (Text -> m -> m) -> e -> m
-- | Free variables are variables which occur anywhere in an expression
-- where they are not bound by an abstraction.
free :: e -> HashSet Text
free = collectVars HS.singleton HS.delete
-- | Bound variables are variables which are abstracted over anywhere in an expression.
bound :: e -> HashSet Text
bound = collectVars (const HS.empty) HS.insert
-- | Used variables are variables which appear *anywhere* in an expression, free or bound.
used :: e -> HashSet Text
used = collectVars HS.singleton HS.insert
-- | Given a map between variable names and expressions,
-- replace each free occurrence of a variable with its respective expression.
substitute :: HashMap Text e -> e -> e
substitute substs = unsafeSubstitute substs . rename (foldMap free substs)
substitute1 :: Text -> e -> e -> e
substitute1 n e = substitute (HM.singleton n e)
-- | Rename all bound variables in an expression (both a binding sites and usage sites)
-- with new names where the new names are *not* members of the provided set.
rename :: HashSet Text -> e -> e
rename1 :: Text -> e -> e
rename1 n = rename (HS.singleton n)
-- | A variant of substitution which does *not* avoid variable capture;
-- it only gives the correct result if the bound variables in the body
-- are disjoint from the free variables in the argument.
unsafeSubstitute :: HashMap Text e -> e -> e
unsafeSubstitute1 :: Text -> e -> e -> e
unsafeSubstitute1 n e = unsafeSubstitute (HM.singleton n e)
--
-- These primitives are likely to be useful for implementing `rename`.
-- Ideally, I would like to find a way to move the implementation of `rename` here entirely,
-- but I haven't yet figured out an appropriate abstraction to do so.
--
-- | Run an action which requires a stateful context of used variable names
-- and a local context of variable replacements.
--
-- This is a useful monad for implementing the `rename` function.
runRenamer :: Substitutable e
=> (HashSet Text -> e -> StateT (HashSet Text) (Reader (HashMap Text Text)) a)
-> HashSet Text
-> e
-> a
runRenamer m ctx e = runReader (evalStateT (m ctx e) dirtyNames) HM.empty
where dirtyNames = HS.union ctx (used e)
-- | Create a new variable name within a context of used variable names.
freshVar :: MonadState (HashSet Text) m => Text -> m Text
freshVar baseName =
state \ctx -> let name = newName ctx in (name, HS.insert name ctx)
where
names = S.iterate (`T.snoc` '\'') baseName
newName ctx = S.head $ S.filter (not . (`HS.member` ctx)) names
-- | Replace a collection of old variable names with new variable names
-- and apply those replacements within a context.
replaceNames :: ( MonadReader (HashMap Text Text) m
, MonadState (HashSet Text) m
, MonadZip t, Traversable t
)
=> HashSet Text -> t Text -> m a -> m (t Text, a)
replaceNames badNames names m = do
newNames <- mapM freshVarIfNecessary names
let replacements = HM.filterWithKey (/=) $ fold $ mzipWith HM.singleton names newNames
x <- local (HM.union replacements) m
pure (newNames, x)
where
freshVarIfNecessary name
| name `HS.member` badNames = freshVar name
| otherwise = pure name
---
--- The same as the above section but for `substitute`.
--- This is useful when implementing substitution as a paramorphism.
---
-- | Run an action in a local context of substitutions.
--
-- This monad is useful for implementing, you guessed it, substitution.
runSubstituter :: (e -> Reader (HashMap Text e) a)
-> HashMap Text e
-> e
-> a
runSubstituter m substs e = runReader (m e) substs
-- | Apply only the substitutions which are not bound,
-- and only if there are substitutions left to apply.
maySubstitute :: ( MonadReader (HashMap Text b) m
, Functor t, Foldable t
)
=> t Text -> (a, m a) -> m a
maySubstitute ns (unmodified, substituted) =
local (compose $ fmap HM.delete ns) do
noMoreSubsts <- asks HM.null
if noMoreSubsts
then pure unmodified
else substituted
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id

View File

@ -1,8 +1,9 @@
module LambdaCalculus.Syntax.Base module LambdaCalculus.Syntax.Base
( Expr (..), ExprF (..), Ctr (..), Pat, Def, DefF (..), PatF (..), VoidF, Text, NonEmpty (..) ( Expr (..), ExprF (..), Ctr (..), Pat, Def, DefF (..), PatF (..), VoidF, Text, NonEmpty (..)
, substitute, substitute1, rename, rename1, free, bound, used
, Parse, AST, ASTF, ASTX, ASTXF (..), NonEmptyDefFs (..) , Parse, AST, ASTF, ASTX, ASTXF (..), NonEmptyDefFs (..)
, pattern LetFP, pattern PNat, pattern PNatF, pattern PList, pattern PListF , pattern LetFP, pattern PNat, pattern PNatF, pattern PList, pattern PListF
, pattern PChar, pattern PCharF, pattern PStr, pattern PStrF , pattern PChar, pattern PCharF, pattern PStr, pattern PStrF, pattern HoleP, pattern HoleFP
, simplify , simplify
) where ) where
@ -53,6 +54,8 @@ data ASTXF r
| PChar_ Char | PChar_ Char
-- | A string literal, e.g. `"abcd"`. -- | A string literal, e.g. `"abcd"`.
| PStr_ Text | PStr_ Text
-- | A type hole.
| HoleP_
deriving (Eq, Functor, Foldable, Traversable, Show) deriving (Eq, Functor, Foldable, Traversable, Show)
instance RecursivePhase Parse where instance RecursivePhase Parse where
@ -62,7 +65,6 @@ instance RecursivePhase Parse where
newtype NonEmptyDefFs r = NonEmptyDefFs { getNonEmptyDefFs :: NonEmpty (Text, r) } newtype NonEmptyDefFs r = NonEmptyDefFs { getNonEmptyDefFs :: NonEmpty (Text, r) }
deriving (Eq, Functor, Foldable, Traversable, Show) deriving (Eq, Functor, Foldable, Traversable, Show)
pattern LetFP :: NonEmpty (Text, r) -> r -> ASTF r pattern LetFP :: NonEmpty (Text, r) -> r -> ASTF r
pattern LetFP ds e = LetF (NonEmptyDefFs ds) e pattern LetFP ds e = LetF (NonEmptyDefFs ds) e
@ -90,10 +92,18 @@ pattern PStrF s = ExprXF (PStr_ s)
pattern PStr :: Text -> AST pattern PStr :: Text -> AST
pattern PStr s = ExprX (PStr_ s) pattern PStr s = ExprX (PStr_ s)
{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, ExprXF #-} pattern HoleP :: AST
{-# COMPLETE Var, App, Abs, Let, Ctr, Case, PNat, PList, PChar, PStr #-} pattern HoleP = ExprX HoleP_
{-# COMPLETE VarF, AppF, AbsF, LetF , CtrF, CaseF, PNatF, PListF, PCharF, PStrF #-}
{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, PNatF, PListF, PCharF, PStrF #-} pattern HoleFP :: ASTF r
pattern HoleFP = ExprXF HoleP_
{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, ExprXF #-}
{-# COMPLETE Var, App, Abs, Let, Ctr, Case, PNat, PList, PChar, PStr, HoleP #-}
{-# COMPLETE VarF, AppF, AbsF, LetF , CtrF, CaseF, PNatF, PListF, PCharF, PStrF, HoleFP #-}
{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, PNatF, PListF, PCharF, PStrF, HoleFP #-}
-- TODO: Implement Substitutable for AST.
-- | Combine nested expressions into compound expressions or literals when possible. -- | Combine nested expressions into compound expressions or literals when possible.
simplify :: AST -> AST simplify :: AST -> AST

View File

@ -86,10 +86,10 @@ ctr = pair <|> unit <|> either <|> nat <|> list <|> str
succ = Ctr CSucc [] <$ keyword "S" succ = Ctr CSucc [] <$ keyword "S"
natLit = (PNat . read <$> many1 digit) <* spaces natLit = (PNat . read <$> many1 digit) <* spaces
list = cons <|> consCtr <|> listLit list = cons <|> consCtr <|> listLit
consCtr = Ctr CCons [] <$ keyword "(:)" consCtr = Ctr CCons [] <$ keyword "(::)"
cons = try $ between (token '(') (token ')') do cons = try $ between (token '(') (token ')') do
e1 <- ambiguous e1 <- ambiguous
token ':' keyword "::"
e2 <- ambiguous e2 <- ambiguous
pure $ Ctr CCons [e1, e2] pure $ Ctr CCons [e1, e2]
listLit = fmap PList $ between (token '[') (token ']') $ sepEndBy ambiguous (token ',') listLit = fmap PList $ between (token '[') (token ']') $ sepEndBy ambiguous (token ',')
@ -105,35 +105,36 @@ pat = label "case alternate" $ do
keyword "->" keyword "->"
e <- ambiguous e <- ambiguous
pure $ Pat c ns e pure $ Pat c ns e
where pair = try $ between (token '(') (token ')') do where
e1 <- identifier pair = try $ between (token '(') (token ')') do
token ',' e1 <- identifier
e2 <- identifier token ','
pure (CPair, [e1, e2]) e2 <- identifier
unit = (CUnit, []) <$ keyword "()" pure (CPair, [e1, e2])
left = do unit = (CUnit, []) <$ keyword "()"
keyword "Left" left = do
e <- identifier keyword "Left"
pure (CLeft, [e]) e <- identifier
right = do pure (CLeft, [e])
keyword "Right" right = do
e <- identifier keyword "Right"
pure (CRight, [e]) e <- identifier
zero = (CZero, []) <$ keyword "Z" pure (CRight, [e])
succ = do zero = (CZero, []) <$ keyword "Z"
keyword "S" succ = do
e <- identifier keyword "S"
pure (CSucc, [e]) e <- identifier
nil = (CNil, []) <$ keyword "[]" pure (CSucc, [e])
cons = try $ between (token '(') (token ')') do nil = (CNil, []) <$ keyword "[]"
e1 <- identifier cons = try $ between (token '(') (token ')') do
token ':' e1 <- identifier
e2 <- identifier keyword "::"
pure (CCons, [e1, e2]) e2 <- identifier
char' = do pure (CCons, [e1, e2])
keyword "Char" char' = do
e <- identifier keyword "Char"
pure (CChar, [e]) e <- identifier
pure (CChar, [e])
case_ :: Parser AST case_ :: Parser AST
case_ = label "case patterns" $ do case_ = label "case patterns" $ do
@ -142,13 +143,16 @@ case_ = label "case patterns" $ do
token '}' token '}'
pure $ Case pats pure $ Case pats
hole :: Parser AST
hole = label "hole" $ HoleP <$ token '_'
-- | Guaranteed to consume a finite amount of input -- | Guaranteed to consume a finite amount of input
finite :: Parser AST finite :: Parser AST
finite = label "finite expression" $ variable <|> ctr <|> case_ <|> grouping finite = label "finite expression" $ variable <|> hole <|> ctr <|> case_ <|> grouping
-- | Guaranteed to consume input, but may continue until it reaches a terminator -- | Guaranteed to consume input, but may continue until it reaches a terminator
block :: Parser AST block :: Parser AST
block = label "block expression" $ finite <|> abstraction <|> let_ block = label "block expression" $ abstraction <|> let_ <|> finite
-- | Not guaranteed to consume input at all, may continue until it reaches a terminator -- | Not guaranteed to consume input at all, may continue until it reaches a terminator
ambiguous :: Parser AST ambiguous :: Parser AST

View File

@ -71,6 +71,7 @@ unparseAST = toStrict . toLazyText . snd . cata \case
in tag Finite $ "[" <> es' <> "]" in tag Finite $ "[" <> es' <> "]"
PStrF s -> tag Finite $ "\"" <> fromText s <> "\"" PStrF s -> tag Finite $ "\"" <> fromText s <> "\""
PCharF c -> tag Finite $ "'" <> fromLazyText (singleton c) PCharF c -> tag Finite $ "'" <> fromLazyText (singleton c)
HoleFP -> tag Finite "_"
where where
unparseApp :: Tagged Builder -> NonEmpty (Tagged Builder) -> Tagged Builder unparseApp :: Tagged Builder -> NonEmpty (Tagged Builder) -> Tagged Builder
unparseApp ef (unsnoc -> (exs, efinal)) unparseApp ef (unsnoc -> (exs, efinal))
@ -78,8 +79,8 @@ unparseAST = toStrict . toLazyText . snd . cata \case
unparseCtr :: Ctr -> [Tagged Builder] -> Tagged Builder unparseCtr :: Ctr -> [Tagged Builder] -> Tagged Builder
-- Fully-applied special syntax forms -- Fully-applied special syntax forms
unparseCtr CPair [x, y] = tag Finite $ "(" <> unambiguous x <> ", " <> unambiguous y <> ")" unparseCtr CPair [x, y] = tag Finite $ "(" <> unambiguous x <> ", " <> unambiguous y <> ")"
unparseCtr CCons [x, y] = tag Finite $ "(" <> unambiguous x <> " : " <> unambiguous y <> ")" unparseCtr CCons [x, y] = tag Finite $ "(" <> unambiguous x <> " :: " <> unambiguous y <> ")"
-- Partially-applied syntax forms -- Partially-applied syntax forms
unparseCtr CUnit [] = tag Finite "()" unparseCtr CUnit [] = tag Finite "()"
unparseCtr CPair [] = tag Finite "(,)" unparseCtr CPair [] = tag Finite "(,)"
@ -88,7 +89,7 @@ unparseAST = toStrict . toLazyText . snd . cata \case
unparseCtr CZero [] = tag Finite "Z" unparseCtr CZero [] = tag Finite "Z"
unparseCtr CSucc [] = tag Finite "S" unparseCtr CSucc [] = tag Finite "S"
unparseCtr CNil [] = tag Finite "[]" unparseCtr CNil [] = tag Finite "[]"
unparseCtr CCons [] = tag Finite "(:)" unparseCtr CCons [] = tag Finite "(::)"
unparseCtr CChar [] = tag Finite "Char" unparseCtr CChar [] = tag Finite "Char"
unparseCtr ctr (x:xs) = unparseApp (unparseCtr ctr []) (x :| xs) unparseCtr ctr (x:xs) = unparseApp (unparseCtr ctr []) (x :| xs)

158
src/LambdaCalculus/Types.hs Normal file
View File

@ -0,0 +1,158 @@
module LambdaCalculus.Types
( module LambdaCalculus.Types.Base
, infer
) where
import LambdaCalculus.Types.Base
import Control.Applicative ((<|>))
import Control.Monad (when)
import Control.Monad.Except (MonadError, throwError)
import Control.Monad.Reader (MonadReader, runReader, asks, local)
import Control.Monad.RWS
( RWST, evalRWST
, MonadState, state
, MonadWriter, tell, listen
)
import Data.Foldable (forM_, toList)
import Data.HashSet qualified as HS
import Data.HashMap.Strict qualified as HM
import Data.Stream (Stream (..), fromList)
import Data.Text qualified as T
fresh :: MonadState (Stream Text) m => m Type
fresh = state \(Cons n ns) -> (TVar n, ns)
inst :: MonadState (Stream Text) m => Scheme -> m Type
inst (TForall ns t) = foldr (\n_n t' -> substitute1 n_n <$> fresh <*> t') (pure t) ns
lookupVar :: (MonadReader Context m, MonadState (Stream Text) m, MonadError Text m) => Text -> m Type
lookupVar n = do
t_polyq <- asks (HM.!? n)
case t_polyq of
Nothing -> throwError $ "Variable not bound: " <> n
Just t_poly -> inst t_poly
generalize :: MonadReader Context m => Type -> m Scheme
generalize t = do
ctx <- asks HM.keysSet
pure $ TForall (toList $ HS.difference (free t) ctx) t
bindVar :: MonadReader Context m => Text -> Type -> m a -> m a
bindVar n t = local (HM.insert n (TForall [] t))
unify :: MonadWriter [Constraint] m => Type -> Type -> m ()
unify t1 t2 = tell [(t1, t2)]
ctrTy :: MonadState (Stream Text) m => Ctr -> m (Type, [Type])
ctrTy = \case
CUnit -> pure (TUnit, [])
CZero -> pure (TNat, [])
CSucc -> pure (TNat, [TNat])
CChar -> pure (TChar, [TNat])
CNil -> mkUnary TList $ const []
CCons -> mkUnary TList \t_a -> [t_a, TApp TList t_a]
CPair -> mkBinary TProd \t_a t_b -> [t_a, t_b]
CLeft -> mkBinary TSum \t_a _ -> [t_a]
CRight -> mkBinary TSum \_ t_b -> [t_b]
where
mkBinary tc tcas = do
t_a <- fresh
t_b <- fresh
pure (tapp [tc, t_a, t_b], tcas t_a t_b)
mkUnary tc tcas = do
t_a <- fresh
pure (TApp tc t_a, tcas t_a)
j :: (MonadError Text m, MonadReader Context m, MonadState (Stream Text) m, MonadWriter [Constraint] m)
=> CheckExpr -> m Type
j (Var name) = lookupVar name
j (App e_fun e_arg) = do
t_ret <- fresh
t_fun <- j e_fun
t_arg <- j e_arg
unify t_fun (tapp [TAbs, t_arg, t_ret])
pure t_ret
j (Abs n_arg e_ret) = do
t_arg <- fresh
t_ret <- bindVar n_arg t_arg $ j e_ret
pure $ tapp [TAbs, t_arg, t_ret]
j (Let (n_x, e_x) e_ret) = do
(t_x_mono, c) <- listen $ j e_x
s <- solve' c
t_x_poly <- generalize $ substitute s t_x_mono
local (HM.insert n_x t_x_poly) $ j e_ret
-- In a case expression:
-- * the pattern for each branch has the same type as the expression being matched, and
-- * the return type for each branch has the same type as the return type of the case expression as a whole.
j (Case ctrs) = do
t_ret <- fresh
t_x <- fresh
forM_ ctrs \(Pat ctr ns_n e) -> do
(t_x', ts_n) <- ctrTy ctr
unify t_x t_x'
when (length ts_n /= length ns_n) $ throwError "Constructor arity mismatch"
t_ret' <- local (HM.union $ HM.fromList $ zip ns_n $ map (TForall []) ts_n) $ j e
unify t_ret t_ret'
pure $ tapp [TAbs, t_x, t_ret]
j (CtrC ctr) = do
(t_ret, ts_n) <- ctrTy ctr
pure $ foldr (\t_a t_r -> tapp [TAbs, t_a, t_r]) t_ret ts_n
j CallCCC = do
t_a <- fresh
t_b <- fresh
pure $ tapp [TAbs, tapp [TAbs, tapp [TAbs, t_a, t_b], t_a], t_a]
j FixC = do
t_a <- fresh
pure $ tapp [TAbs, tapp [TAbs, t_a, t_a], t_a]
j HoleC = asks show >>= throwError . (<>) "Encountered hole with context: " . T.pack
occurs :: Text -> Type -> Bool
occurs n t = HS.member n (free t)
findDifference :: MonadError (Type, Type) m => Type -> Type -> m (Maybe (Text, Type))
findDifference t1 t2
| t1 == t2 = pure Nothing
| TVar n1 <- t1, not (occurs n1 t2) = pure $ Just (n1, t2)
| TVar _ <- t2 = findDifference t2 t1
| TApp a1 b1 <- t1, TApp a2 b2 <- t2 = (<|>) <$> findDifference a1 a2 <*> findDifference b1 b2
| otherwise = throwError (t1, t2)
unifies :: MonadError (Type, Type) m => Type -> Type -> m Substitution
unifies t1 t2 = do
dq <- findDifference t1 t2
case dq of
Nothing -> pure HM.empty
Just s -> do
ss <- unifies (uncurry substitute1 s t1) (uncurry substitute1 s t2)
pure $ uncurry HM.insert (fmap (substitute ss) s) ss
solve :: MonadError (Type, Type) m => [Constraint] -> m Substitution
solve [] = pure HM.empty
solve (c:cs) = do
s <- uncurry unifies c
ss <- solve (substituteMono s cs)
pure $ HM.union ss (substituteMono ss s)
solve' :: MonadError Text m => [Constraint] -> m Substitution
solve' c = case solve c of
Right ss -> pure ss
Left (t1, t2) -> throwError $ "Could not unify " <> unparseType t1 <> " with " <> unparseType t2
type Inferencer a = RWST Context [Constraint] (Stream Text) (Either Text) a
runInferencer :: Inferencer a -> Either Text (a, [Constraint])
runInferencer m = evalRWST m HM.empty freshNames
where
freshNames = fromList $ do
n <- [0 :: Int ..]
c <- ['a'..'z']
pure $ T.pack if n == 0 then [c] else c : show n
infer :: CheckExpr -> Either Text Scheme
infer e = do
(t, c) <- runInferencer $ j e
s <- solve' c
let t' = substitute s t
pure $ runReader (generalize t') HM.empty

View File

@ -0,0 +1,242 @@
{-# LANGUAGE TemplateHaskell #-}
module LambdaCalculus.Types.Base
( Identity (..)
, Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), VoidF, UnitF (..), Text
, substitute, substitute1, rename, rename1, free, bound, used
, Check, CheckExpr, CheckExprF, CheckX, CheckXF (..)
, pattern AppFC, pattern CtrC, pattern CtrFC, pattern CallCCC, pattern CallCCFC
, pattern FixC, pattern FixFC, pattern HoleC, pattern HoleFC
, Type (..), TypeF (..), Scheme (..), tapp
, Substitution, Context, Constraint
, MonoSubstitutable, substituteMono, substituteMono1
, unparseType, unparseScheme
) where
import Control.Monad (forM)
import Control.Monad.Reader (asks)
import Data.Bifunctor (bimap, first)
import Data.Foldable (fold)
import Data.Functor.Foldable (embed, cata, para)
import Data.Functor.Foldable.TH (makeBaseFunctor)
import Data.Functor.Identity (Identity (..))
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM
import Data.List (foldl1')
import Data.Text qualified as T
import Data.Traversable (for)
import LambdaCalculus.Expression.Base
data Check
type CheckExpr = Expr Check
type instance AppArgs Check = CheckExpr
type instance AbsArgs Check = Text
type instance LetArgs Check = (Text, CheckExpr)
type instance CtrArgs Check = UnitF CheckExpr
type instance XExpr Check = CheckX
type CheckX = CheckXF CheckExpr
type CheckExprF = ExprF Check
type instance AppArgsF Check = Identity
type instance LetArgsF Check = DefF
type instance CtrArgsF Check = UnitF
type instance XExprF Check = CheckXF
data CheckXF r
-- | Call-with-current-continuation.
= CallCCC_
-- | A fixpoint combinator,
-- because untyped lambda calculus fixpoint combinators won't typecheck.
| FixC_
-- | A hole to ask the type inferencer about the context for debugging purposes.
| HoleC_
deriving (Eq, Functor, Foldable, Traversable, Show)
pattern CtrC :: Ctr -> CheckExpr
pattern CtrC c = Ctr c Unit
pattern CtrFC :: Ctr -> CheckExprF r
pattern CtrFC c = CtrF c Unit
pattern AppFC :: r -> r -> CheckExprF r
pattern AppFC ef ex = AppF ef (Identity ex)
pattern CallCCC :: CheckExpr
pattern CallCCC = ExprX CallCCC_
pattern CallCCFC :: CheckExprF r
pattern CallCCFC = ExprXF CallCCC_
pattern FixC :: CheckExpr
pattern FixC = ExprX FixC_
pattern FixFC :: CheckExprF r
pattern FixFC = ExprXF FixC_
pattern HoleC :: CheckExpr
pattern HoleC = ExprX HoleC_
pattern HoleFC :: CheckExprF r
pattern HoleFC = ExprXF HoleC_
{-# COMPLETE Var, App, Abs, Let, CtrC, Case, ExprX #-}
{-# COMPLETE VarF, AppF, AbsF, LetF, CtrFC, CaseF, ExprXF #-}
{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrF, CaseF, ExprXF #-}
{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrFC, CaseF, ExprXF #-}
{-# COMPLETE Var, App, Abs, Let, Ctr, Case, CallCCC, FixC, HoleC #-}
{-# COMPLETE Var, App, Abs, Let, CtrC, Case, CallCCC, FixC, HoleC #-}
{-# COMPLETE VarF, AppF, AbsF, LetF, CtrFC, CaseF, CallCCFC, FixFC, HoleFC #-}
{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrF, CaseF, CallCCFC, FixFC, HoleFC #-}
{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrFC, CaseF, CallCCFC, FixFC, HoleFC #-}
instance RecursivePhase Check where
projectAppArgs = Identity
projectLetArgs = projectDef
embedAppArgs = runIdentity
embedLetArgs = embedDef
instance Substitutable CheckExpr where
collectVars withVar withBinder = cata \case
VarF n -> withVar n
AbsF n e -> withBinder n e
LetF (Def n x) e -> x <> withBinder n e
CaseF pats -> foldMap (\(Pat _ ns e) -> foldr withBinder e ns) pats
e -> fold e
rename = runRenamer $ \badNames -> cata \case
VarF n -> asks $ Var . HM.findWithDefault n n
AbsF n e -> uncurry Abs . first runIdentity <$> replaceNames badNames (Identity n) e
LetF (Def n x) e -> do
x' <- x
(Identity n', e') <- replaceNames badNames (Identity n) e
pure $ Let (n', x') e'
CaseF ps -> Case <$> forM ps \(Pat ctr ns e) ->
uncurry (Pat ctr) <$> replaceNames badNames ns e
e -> embed <$> sequenceA e
unsafeSubstitute = runSubstituter $ para \case
VarF name -> asks $ HM.findWithDefault (Var name) name
AbsF name e -> Abs name <$> maySubstitute (Identity name) e
LetF (Def name (_, x)) e -> do
x' <- x
e' <- maySubstitute (Identity name) e
pure $ Let (name, x') e'
CaseF pats -> Case <$> for pats \(Pat ctr ns e) -> Pat ctr ns <$> maySubstitute ns e
e -> embed <$> traverse snd e
-- | A monomorphic type.
data Type
-- | Type variable.
= TVar Text
-- | Type application.
| TApp Type Type
-- | The function type.
| TAbs
-- | The product type.
| TProd
-- | The sum type.
| TSum
-- | The unit type.
| TUnit
-- | The empty type.
| TVoid
-- | The type of natural numbers.
| TNat
-- | The type of lists.
| TList
-- | The type of characters.
| TChar
deriving (Eq, Show)
makeBaseFunctor ''Type
instance Substitutable Type where
collectVars withVar _ = cata \case
TVarF n -> withVar n
t -> fold t
-- /All/ variables in a monomorphic type are free.
rename _ t = t
-- No renaming step is necessary.
substitute substs = cata \case
TVarF n -> HM.findWithDefault (TVar n) n substs
e -> embed e
unsafeSubstitute = substitute
-- | A polymorphic type.
data Scheme
-- | Universally quantified type variables.
= TForall [Text] Type
deriving (Eq, Show)
instance Substitutable Scheme where
collectVars withVar withBinder (TForall names t) =
foldMap withBinder names $ collectVars withVar withBinder t
rename = runRenamer \badNames (TForall names t) ->
uncurry TForall <$> replaceNames badNames names (pure t)
-- I took a shot at implementing this but found it to be quite difficult
-- because merging the foralls is tricky.
-- It's not undoable, but it wasn't worth my further time investment
-- seeing as this function isn't currently used anywhere.
unsafeSubstitute = error "Substitution for schemes not yet implemented"
type Substitution = HashMap Text Type
type Context = HashMap Text Scheme
type Constraint = (Type, Type)
class MonoSubstitutable t where
substituteMono :: Substitution -> t -> t
substituteMono1 :: Text -> Type -> t -> t
substituteMono1 var val= substituteMono (HM.singleton var val)
instance MonoSubstitutable Type where
substituteMono = substitute
instance MonoSubstitutable Scheme where
substituteMono substs (TForall names t) =
TForall names $ substitute (foldMap HM.delete names substs) t
instance MonoSubstitutable Constraint where
substituteMono substs = bimap (substituteMono substs) (substituteMono substs)
instance MonoSubstitutable t => MonoSubstitutable [t] where
substituteMono = fmap . substituteMono
instance MonoSubstitutable t => MonoSubstitutable (HashMap a t) where
substituteMono = fmap . substituteMono
tapp :: [Type] -> Type
tapp [] = error "Empty type applications are not permitted"
tapp [t] = t
tapp ts = foldl1' TApp ts
-- HACK
pattern TApp2 :: Type -> Type -> Type -> Type
pattern TApp2 tf tx ty = TApp (TApp tf tx) ty
-- TODO: Improve these printers.
unparseType :: Type -> Text
unparseType (TVar name) = name
unparseType (TApp2 TAbs a b) = "(" <> unparseType a <> " -> " <> unparseType b <> ")"
unparseType (TApp2 TProd a b) = "(" <> unparseType a <> " * " <> unparseType b <> ")"
unparseType (TApp2 TSum a b) = "(" <> unparseType a <> " + " <> unparseType b <> ")"
unparseType (TApp TList a) = "[" <> unparseType a <> "]"
unparseType (TApp a b) = "(" <> unparseType a <> " " <> unparseType b <> ")"
unparseType TAbs = "(->)"
unparseType TProd = "(*)"
unparseType TSum = "(+)"
unparseType TUnit = ""
unparseType TVoid = ""
unparseType TNat = "Nat"
unparseType TList = "[]"
unparseType TChar = "Char"
unparseScheme :: Scheme -> Text
unparseScheme (TForall [] t) = unparseType t
unparseScheme (TForall names t) = "" <> T.unwords names <> ". " <> unparseType t

View File

@ -39,10 +39,10 @@ omega = App x x
where x = Abs "x" (App (Var "x") (Var "x")) where x = Abs "x" (App (Var "x") (Var "x"))
cc1 :: EvalExpr cc1 :: EvalExpr
cc1 = App CallCC (Abs "k" (App omega (App (Var "k") (Var "z")))) cc1 = App CallCCE (Abs "k" (App omega (App (Var "k") (Var "z"))))
cc2 :: EvalExpr cc2 :: EvalExpr
cc2 = App (Var "y") (App CallCC (Abs "k" (App (Var "z") (App (Var "k") (Var "x"))))) cc2 = App (Var "y") (App CallCCE (Abs "k" (App (Var "z") (App (Var "k") (Var "x")))))
main :: IO () main :: IO ()
main = defaultMain $ main = defaultMain $