Added support for Hindley-Milner type inference!

I also generalized substitution to *all* expressions (and types),
which itself involved a rewrite of the substitution function I already had.
(In hindsight, I wish I'd done that in a separate commit.)

The last four days have all been working up to this,
so I'm glad I've finally managed to integrate it into this project!
(I wrote the algorithm four days ago, but the infrastructure
just wasn't there to add it here.)
master
James T. Martin 2021-03-18 00:00:43 -07:00
parent 586be18c80
commit 9e0754daf6
Signed by: james
GPG Key ID: 4B7F3DA9351E577C
15 changed files with 806 additions and 254 deletions


@ -1,11 +1,14 @@
# Lambda Calculus
This is a simple programming language derived from lambda calculus.
This is a simple programming language derived from lambda calculus,
using the Hindley-Milner type system, plus some builtin types, `fix`, and `callcc`.
## Usage
Run the program using `stack run` (or run the tests with `stack test`).
Type in your expression at the prompt: `>> `.
The expression will be evaluated to normal form using the call-by-value evaluation strategy and then printed.
Type in your expression at the prompt: `>> `. Two things then happen:
* the type of the expression is inferred and printed,
* then, regardless of whether typechecking succeeded, the expression is evaluated to normal form using the call-by-value evaluation strategy and printed.
Exit the prompt with `Ctrl-c` (or equivalent).
## Syntax
@ -16,23 +19,38 @@ The parser's error messages currently are virtually useless, so be very careful
* Lambda abstraction: `\x y z. E` or `λx y z. E`
* Let expressions: `let x = E; y = F in G`
* Parenthetical expressions: `(E)`
* Constructors: `()`, `(x, y)` (or `(,) x y`), `Left x`, `Right y`, `Z`, `S`, `[]`, `(x : xs)` (or `(:) x xs`), `Char n`.
* Constructors: `()`, `(x, y)` (or `(,) x y`), `Left x`, `Right y`, `Z`, `S`, `[]`, `(x :: xs)` (or `(::) x xs`), `Char n`.
* The parentheses around the cons constructor are not optional.
* `Char` takes a natural number and turns it into a character.
* Pattern matchers: `{ Left x -> e ; Right y -> f }`
* Pattern matchers: `case { Left a -> e ; Right y -> f }`
* Pattern matchers can be applied like functions, e.g. `{ Z -> x ; S n -> y } 10` reduces to `y`.
* Patterns must use the regular form of the constructor, e.g. `(x : xs)` and not `((:) x xs)`.
* Patterns must use the regular form of the constructor, e.g. `(x :: xs)` and not `((::) x xs)`.
* There are no nested patterns or default patterns.
* Incomplete pattern matches will crash the interpreter.
* Literals: `1234`, `[e, f, g, h]`, `'a`, `"abc"`
* Strings are represented as lists of characters.
* Type annotations: there are no type annotations; types are inferred only.
## Call/CC
This interpreter has preliminary support for
[the call-with-current-continuation control flow operator](https://en.wikipedia.org/wiki/Call-with-current-continuation).
However, it has not been thoroughly tested.
## Types
Types are checked/inferred using the Hindley-Milner type inference algorithm.
To use it, simply apply the variable `callcc` like you would a function, e.g. `(callcc (\k. ...))`.
* Functions: `a -> b` (constructed by `\x. e`)
* Products: `a * b` (constructed by `(x, y)`)
* Unit: `★` (constructed by `()`)
* Sums: `a + b` (constructed by `Left x` or `Right y`)
* Bottom: `⊥` (currently useless because incomplete patterns are allowed)
* The natural numbers: `Nat` (constructed by `Z` and `S`)
* Lists: `List a` (constructed by `[]` and `(x :: xs)`)
* Characters: `Char` (constructed by `Char`, which takes a `Nat`)
* Universal quantification (forall): `∀a b. t`
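
The core of Hindley-Milner inference is unification of types. Below is a minimal sketch of that step, not the code in this commit: the names `Ty`, `Subst`, `apply`, `occurs` and `unify` are hypothetical, and the type language is cut down to variables, functions and `Nat`. The occurs check is what rejects self-application such as `\x. x x`.

```haskell
-- Hypothetical, simplified type language: variables, functions, and Nat.
data Ty = TVar String | TFun Ty Ty | TNat
  deriving (Eq, Show)

-- A substitution maps type variable names to types.
type Subst = [(String, Ty)]

-- Apply a substitution to a type, following chains of bindings.
apply :: Subst -> Ty -> Ty
apply s (TVar a)   = maybe (TVar a) (apply s) (lookup a s)
apply s (TFun a b) = TFun (apply s a) (apply s b)
apply _ TNat       = TNat

-- Does the type variable occur anywhere inside the type?
occurs :: String -> Ty -> Bool
occurs a (TVar b)   = a == b
occurs a (TFun l r) = occurs a l || occurs a r
occurs _ TNat       = False

-- Extend the substitution so that both types become equal, or fail.
unify :: Ty -> Ty -> Subst -> Either String Subst
unify t1 t2 s = go (apply s t1) (apply s t2)
  where
    go (TVar a) (TVar b) | a == b = Right s
    go (TVar a) t                 = bind a t
    go t (TVar a)                 = bind a t
    go TNat TNat                  = Right s
    go (TFun a b) (TFun c d)      = unify a c s >>= unify b d
    go l r = Left ("cannot unify " ++ show l ++ " with " ++ show r)

    -- Bind a variable, refusing to build an infinite type.
    bind a t
      | occurs a t = Left ("occurs check failed: " ++ a ++ " in " ++ show t)
      | otherwise  = Right ((a, t) : s)
```

For instance, unifying `a -> Nat` with `Nat -> b` binds both `a` and `b` to `Nat`, while unifying `a` with `a -> Nat` fails the occurs check.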
## Builtins
Builtins are variables that correspond with a built-in language feature
that cannot be replicated by user-written code.
They still are just variables though; they do not receive special syntactic treatment.
* `fix : ∀a. ((a -> a) -> a)`: an alias for the strict fixpoint combinator that the typechecker can typecheck.
* `callcc : ∀a b. (((a -> b) -> a) -> a)`: [the call-with-current-continuation control flow operator](https://en.wikipedia.org/wiki/Call-with-current-continuation).
Continuations are printed as `λ!. ... ! ...`, like a lambda abstraction
with an argument named `!` which is used exactly once;
@ -41,56 +59,61 @@ because they perform the side effect of modifying the current continuation,
and this is *not* valid syntax you can input into the REPL.
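
To illustrate the type given for `callcc`, here is a rough Haskell analogue using a hand-written continuation monad. This is only a sketch: `Cont`, `callCC`, `example` and `evalCont` are names assumed for illustration, and the interpreter itself implements continuations differently (by capturing the evaluator's own context).

```haskell
-- A minimal continuation monad, written by hand so only base is needed.
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f (Cont c) = Cont (\k -> c (k . f))

instance Applicative (Cont r) where
  pure a = Cont (\k -> k a)
  Cont cf <*> Cont ca = Cont (\k -> cf (\f -> ca (k . f)))

instance Monad (Cont r) where
  Cont c >>= f = Cont (\k -> c (\a -> runCont (f a) k))

-- callCC hands its body an escape function; calling it discards the
-- rest of the body and resumes the captured continuation instead.
callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a
callCC f = Cont (\k -> runCont (f (\a -> Cont (\_ -> k a))) k)

-- Loosely mirrors `S (callcc \k. (fix \x. x) (k Z))`:
-- the escape runs first, so the diverging value is never demanded.
example :: Cont r Int
example = do
  n <- callCC $ \k -> k 0 >> pure (diverge ())
  pure (n + 1)
  where
    diverge :: () -> Int
    diverge () = diverge ()

-- Run a computation with the identity continuation.
evalCont :: Cont a a -> a
evalCont c = runCont c id
```

Here `evalCont example` evaluates to `1`, because `k 0` jumps straight back to the point where `callCC` was called.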
## Example code
The fixpoint function:
```
(\x. x x) \fix f x. f (fix fix f) x
```
Create a list by iterating `f` `n` times:
```
fix \iterate f x. { Z -> x ; S n -> iterate f (f x) n }
fix \iterate f x. { Z -> [] ; S n -> (x :: iterate f (f x) n) }
: ∀e. ((e -> e) -> (e -> (Nat -> [e])))
```
Create a list whose first element is `n - 1`, counting down to a last element of `0`:
Use the iterate function to count to 10:
```
\n. { (n, x) -> x } (iterate { (n, x) -> (S n, (n : x)) } (0, []) n)
```
Putting it all together to count down from 10:
```
>> let fix = (\x. x x) \fix f x. f (fix fix f) x; iterate = fix \iterate f x. { Z -> x ; S n -> iterate f (f x) n }; countDownFrom = \n. { (n, x) -> x } (iterate { (n, x) -> (S n, (n : x)) } (0, []) n) in countDownFrom 10
[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
>> let iterate = fix \iterate f x. { Z -> [] ; S n -> (x :: iterate f (f x) n) }; countTo = iterate S 1 in countTo 10
: [Nat]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```
Append two lists together:
```
fix \append xs ys. { [] -> ys ; (x : xs) -> (x : append xs ys) } xs
fix \append xs ys. { [] -> ys ; (x :: xs) -> (x :: append xs ys) } xs
: ∀j. ([j] -> ([j] -> [j]))
```
Reverse a list:
```
fix \reverse. { [] -> [] ; (x : xs) -> append (reverse xs) [x] }
fix \reverse. { [] -> [] ; (x :: xs) -> append (reverse xs) [x] }
: ∀c1. ([c1] -> [c1])
```
Putting them together so we can reverse `"reverse"`:
```
>> let fix = (\x. x x) \fix f x. f (fix fix f) x; append = fix \append xs ys. { [] -> ys ; (x : xs) -> (x : append xs ys) } xs; reverse = fix \reverse. { [] -> [] ; (x : xs) -> append (reverse xs) [x] } in reverse "reverse"
>> let append = fix \append xs ys. { [] -> ys ; (x :: xs) -> (x :: append xs ys) } xs; reverse = fix \reverse. { [] -> [] ; (x :: xs) -> append (reverse xs) [x] } in reverse "reverse"
: [Char]
"esrever"
```
Calculating `3 + 2` with the help of Church-encoded numerals:
```
>> let Sf = \n f x. f (n f x); plus = \x. x Sf in plus (\f x. f (f (f x))) (\f x. f (f x)) S Z
: Nat
5
```
This expression would loop forever, but `callcc` saves the day!
```
>> y (callcc \k. (\x. (\x. x x) (\x. x x)) (k z))
y z
>> S (callcc \k. (fix \x. x) (k Z))
: Nat
1
```
A few other weird expressions:
And if it wasn't clear, this is what the `Char` constructor does:
```
>> { Char c -> Char (S c) } 'a
: Char
'b
```
Here are a few expressions which don't typecheck but are handy for debugging the evaluator:
```
>> let D = \x. x x; F = \f. f (f y) in D (F \x. x)
y y
@ -98,6 +121,4 @@ y y
y
>> (\x y z. x y) y
λy' z. y y'
>> { Char c -> Char (S c) } 'a
'b
```


@ -15,8 +15,10 @@ prompt text = do
getLine
main :: IO ()
main = forever $ parseEval <$> prompt ">> " >>= \case
main = forever $ parseCheck <$> prompt ">> " >>= \case
Left parseError -> putStrLn $ "Parse error: " <> pack (show parseError)
-- TODO: Support choosing which version to use at runtime.
Right expr -> putStrLn $ unparseEval $ eval expr
--Right expr -> mapM_ (putStrLn . unparseEval) $ snd $ traceEval expr
Right expr -> do
either putStrLn (putStrLn . (": " <>) . unparseScheme) $ infer expr
putStrLn $ unparseEval $ eval $ check2eval expr
--mapM_ (putStrLn . unparseEval) $ snd $ traceEval $ check2eval expr


@ -20,9 +20,11 @@ default-extensions:
- FlexibleContexts
- FlexibleInstances
- ImportQualifiedPost
- InstanceSigs
- LambdaCase
- OverloadedStrings
- PatternSynonyms
- ScopedTypeVariables
- StandaloneDeriving
- ViewPatterns
# Required for use of the 'trees that grow' pattern
@ -32,7 +34,6 @@ default-extensions:
- DeriveFoldable
- DeriveFunctor
- DeriveTraversable
- TemplateHaskell
dependencies:
- base >= 4.14 && < 5


@ -2,15 +2,23 @@ module LambdaCalculus
( module LambdaCalculus.Evaluator
, module LambdaCalculus.Expression
, module LambdaCalculus.Syntax
, parseEval, unparseEval
, module LambdaCalculus.Types
, parseCheck, parseEval, unparseCheck, unparseEval
) where
import LambdaCalculus.Evaluator
import LambdaCalculus.Expression
import LambdaCalculus.Syntax
import LambdaCalculus.Types
parseCheck :: Text -> Either ParseError CheckExpr
parseCheck = fmap ast2check . parseAST
parseEval :: Text -> Either ParseError EvalExpr
parseEval = fmap ast2eval . parseAST
unparseCheck :: CheckExpr -> Text
unparseCheck = unparseAST . simplify . check2ast
unparseEval :: EvalExpr -> Text
unparseEval = unparseAST . simplify . eval2ast


@ -2,130 +2,29 @@ module LambdaCalculus.Evaluator
( Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), VoidF, UnitF (..), Text
, Eval, EvalExpr, EvalX, EvalXF (..)
, pattern AppFE, pattern CtrE, pattern CtrFE
, pattern Cont, pattern ContF, pattern CallCC, pattern CallCCF
, eval, traceEval, substitute, alphaConvert
, pattern ContE, pattern ContFE, pattern CallCCE, pattern CallCCFE
, eval, traceEval
) where
import LambdaCalculus.Evaluator.Base
import LambdaCalculus.Evaluator.Continuation
import Control.Monad (forM)
import Control.Monad.Except (MonadError, ExceptT, throwError, runExceptT)
import Control.Monad.State (MonadState, State, evalState, modify', state, put, gets)
import Control.Monad.State (MonadState, evalState, modify', state, put, gets)
import Control.Monad.Writer (runWriterT, tell)
import Data.Foldable (fold)
import Data.Functor.Foldable (cata, para, embed)
import Data.HashSet (HashSet)
import Data.HashSet qualified as HS
import Data.Stream qualified as S
import Data.Text qualified as T
import Data.HashMap.Strict qualified as HM
import Data.Void (Void, absurd)
-- | Free variables are variables which are present in an expression but not bound by any abstraction.
freeVars :: EvalExpr -> HashSet Text
freeVars = cata \case
VarF n -> HS.singleton n
AbsF n e -> HS.delete n e
ContF e -> HS.delete "!" e
CaseF ps -> foldMap (\(Pat _ ns e) -> HS.difference e (HS.fromList ns)) ps
e -> fold e
-- | Bound variables are variables which are bound by any form of abstraction in an expression.
boundVars :: EvalExpr -> HashSet Text
boundVars = cata \case
AbsF n e -> HS.insert n e
ContF e -> HS.insert "!" e
CaseF ps -> foldMap (\(Pat _ ns e) -> HS.union (HS.fromList ns) e) ps
e -> fold e
-- | Vars that occur anywhere in an expression, bound or free.
usedVars :: EvalExpr -> HashSet Text
usedVars x = HS.union (freeVars x) (boundVars x)
-- | Substitution is the process of replacing all free occurrences of a variable in one expression with another expression.
substitute :: Text -> EvalExpr -> EvalExpr -> EvalExpr
substitute var val = unsafeSubstitute var val . alphaConvert (freeVars val)
-- | Substitution is only safe if the bound variables in the body
-- are disjoint from the free variables in the argument;
-- this function makes an expression body safe for substitution
-- by replacing the bound variables in the body
-- with completely new variables which do not occur in either expression
-- (without changing any *free* variables in the body, of course).
alphaConvert :: HashSet Text -> EvalExpr -> EvalExpr
alphaConvert ctx e_ = evalState (alphaConverter e_) $ HS.union ctx (usedVars e_)
where
alphaConverter :: EvalExpr -> State (HashSet Text) EvalExpr
alphaConverter = cata \case
e
| AbsF n e' <- e, n `HS.member` ctx -> do
n' <- fresh n
e'' <- e'
pure $ Abs n' $ replace n n' e''
-- | TODO: Only replace the names that *have* to be replaced.
| CaseF ps <- e, any (any (`HS.member` ctx) . patNames) ps ->
Case <$> forM ps \(Pat ctr ns e') -> do
ns' <- mapM fresh ns
e'' <- e'
pure $ Pat ctr ns' $ foldr (uncurry replace) e'' (zip ns ns')
| otherwise -> embed <$> sequenceA e
-- | Create a new name which is not used anywhere else.
fresh :: Text -> State (HashSet Text) Text
fresh n = state \ctx' ->
let n' = S.head $ S.filter (not . (`HS.member` ctx')) names
in (n', HS.insert n' ctx')
where names = S.iterate (`T.snoc` '\'') n
-- | Replace a name with an entirely new name in all contexts.
-- This will only give correct results if
-- the new name does not occur anywhere in the expression.
replace :: Text -> Text -> EvalExpr -> EvalExpr
replace name name' = cata \case
e
| VarF name2 <- e, name == name2 -> Var name'
| AbsF name2 e' <- e, name == name2 -> Abs name' e'
| CaseF ps <- e -> Case $ flip map ps \(Pat ctr ns e') -> Pat ctr (replace' ns) e'
| otherwise -> embed e
where
replace' = map \case
n
| n == name -> name'
| otherwise -> n
-- | Substitution which does *not* avoid variable capture;
-- it only gives the correct result if the bound variables in the body
-- are disjoint from the free variables in the argument.
unsafeSubstitute :: Text -> EvalExpr -> EvalExpr -> EvalExpr
unsafeSubstitute var val = para \case
e'
| VarF var2 <- e', var == var2 -> val
| AbsF var2 _ <- e', var == var2 -> unmodified e'
| ContF _ <- e', var == "!" -> unmodified e'
| CaseF ps <- e' -> Case $ flip map ps \(Pat ctr ns (unmod, sub)) ->
Pat ctr ns if var `elem` ns then unmod else sub
| otherwise -> substituted e'
where
substituted, unmodified :: EvalExprF (EvalExpr, EvalExpr) -> EvalExpr
substituted = embed . fmap snd
unmodified = embed . fmap fst
isReducible :: EvalExpr -> Bool
isReducible = snd . cata \case
AppFE ctr args -> active ctr [args]
AbsF _ _ -> passive
ContF _ -> passive
CaseF _ -> passive
CallCCF -> passive
CtrFE _ -> constant
VarF _ -> constant
where
-- | Constants are irreducible in any context.
constant = (False, False)
-- | Passive expressions are reducible only if an active expression is applied to them.
passive = (True, False)
-- | Active expressions are reducible if they are applied to a constructor or their arguments are reducible.
active ctr args = (False, fst ctr || snd ctr || any snd args)
-- Applications of function type constructors
isReducible (App (Abs _ _) _) = True
isReducible (App (ContE _) _) = True
isReducible (App CallCCE _) = True
-- Pattern matching of data
isReducible (App (Case _) ex) = isData ex || isReducible ex
-- Reducible subexpressions
isReducible (App ef ex) = isReducible ef || isReducible ex
isReducible _ = False
lookupPat :: Ctr -> [Pat phase] -> Pat phase
lookupPat ctr = foldr lookupCtr' (error "Constructor not found")
@ -182,21 +81,22 @@ evaluatorStep = \case
| otherwise -> case ef of
-- perform beta reduction if possible...
Abs name body ->
pure $ substitute name ex body
Case pats
| isData ex -> do
let (ctr, xs) = toData ex
let Pat _ ns e = lookupPat ctr pats
pure $ foldr (uncurry substitute) e (zip ns xs)
| otherwise -> ret unmodified
pure $ substitute1 name ex body
Case pats ->
if isData ex
then do
let (ctr, xs) = toData ex
let Pat _ ns e = lookupPat ctr pats
pure $ substitute (HM.fromList $ zip ns xs) e
else ret unmodified
-- perform continuation calls if possible...
Cont body -> do
ContE body -> do
put []
pure $ substitute "!" ex body
pure $ substitute1 "!" ex body
-- capture the current continuation if requested...
CallCC -> do
CallCCE -> do
k <- gets $ continue (Var "!")
pure $ App ex (Cont k)
pure $ App ex (ContE k)
-- otherwise the value is irreducible and we can continue evaluation.
_ -> ret unmodified
-- Neither abstractions, constructors nor variables are reducible.


@ -1,14 +1,22 @@
module LambdaCalculus.Evaluator.Base
( Identity (..)
, Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), VoidF, UnitF (..), Text
, substitute, substitute1, rename, rename1, free, bound, used
, Eval, EvalExpr, EvalExprF, EvalX, EvalXF (..)
, pattern AppFE, pattern CtrE, pattern CtrFE
, pattern Cont, pattern ContF, pattern CallCC, pattern CallCCF
, pattern ContE, pattern ContFE, pattern CallCCE, pattern CallCCFE
) where
import LambdaCalculus.Expression.Base
import Control.Monad (forM)
import Control.Monad.Reader (asks)
import Data.Bifunctor (first)
import Data.Foldable (fold)
import Data.Functor.Identity (Identity (..))
import Data.Functor.Foldable (embed, cata, para)
import Data.HashMap.Strict qualified as HM
import Data.Traversable (for)
data Eval
type EvalExpr = Expr Eval
@ -33,41 +41,62 @@ data EvalXF r
--
-- Continuations do not have any corresponding surface-level syntax,
-- but may be printed like a lambda with the illegal variable `!`.
= Cont_ !r
= ContE_ !r
-- | Call-with-current-continuation, an evaluator built-in function.
| CallCC_
| CallCCE_
deriving (Eq, Functor, Foldable, Traversable, Show)
instance RecursivePhase Eval where
projectAppArgs = Identity
embedAppArgs = runIdentity
pattern CtrE :: Ctr -> EvalExpr
pattern CtrE c = Ctr c Unit
pattern CtrFE :: Ctr -> EvalExprF r
pattern CtrFE c = CtrF c Unit
pattern Cont :: EvalExpr -> EvalExpr
pattern Cont e = ExprX (Cont_ e)
pattern ContE :: EvalExpr -> EvalExpr
pattern ContE e = ExprX (ContE_ e)
pattern CallCC :: EvalExpr
pattern CallCC = ExprX CallCC_
pattern CallCCE :: EvalExpr
pattern CallCCE = ExprX CallCCE_
pattern ContF :: r -> EvalExprF r
pattern ContF e = ExprXF (Cont_ e)
pattern ContFE :: r -> EvalExprF r
pattern ContFE e = ExprXF (ContE_ e)
pattern CallCCF :: EvalExprF r
pattern CallCCF = ExprXF CallCC_
pattern CallCCFE :: EvalExprF r
pattern CallCCFE = ExprXF CallCCE_
pattern AppFE :: r -> r -> EvalExprF r
pattern AppFE ef ex = AppF ef (Identity ex)
{-# COMPLETE Var, App, Abs, Let, Ctr, Case, Cont, CallCC #-}
{-# COMPLETE VarF, AppF, AbsF, LetF, CtrF, CaseF, ContF, CallCCF #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ExprXF #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ContF, CallCCF #-}
{-# COMPLETE Var, App, Abs, Let, CtrE, Case, Cont, CallCC #-}
{-# COMPLETE VarF, AppF, AbsF, LetF, CtrFE, CaseF, ContF, CallCCF #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ExprXF #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ContF, CallCCF #-}
{-# COMPLETE Var, App, Abs, Let, Ctr, Case, ContE, CallCCE #-}
{-# COMPLETE VarF, AppF, AbsF, LetF, CtrF, CaseF, ContFE, CallCCFE #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ExprXF #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrF, CaseF, ContFE, CallCCFE #-}
{-# COMPLETE Var, App, Abs, Let, CtrE, Case, ContE, CallCCE #-}
{-# COMPLETE VarF, AppF, AbsF, LetF, CtrFE, CaseF, ContFE, CallCCFE #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ExprXF #-}
{-# COMPLETE VarF, AppFE, AbsF, LetF, CtrFE, CaseF, ContFE, CallCCFE #-}
instance RecursivePhase Eval where
projectAppArgs = Identity
embedAppArgs = runIdentity
instance Substitutable EvalExpr where
collectVars withVar withBinder = cata \case
VarF n -> withVar n
AbsF n e -> withBinder n e
CaseF pats -> foldMap (\(Pat _ ns e) -> foldr withBinder e ns) pats
e -> fold e
rename = runRenamer $ \badNames -> cata \case
VarF n -> asks $ Var . HM.findWithDefault n n
AbsF n e -> uncurry Abs . first runIdentity <$> replaceNames badNames (Identity n) e
ContFE e -> ContE <$> e
CaseF ps -> Case <$> forM ps \(Pat ctr ns e) -> uncurry (Pat ctr) <$> replaceNames badNames ns e
e -> embed <$> sequenceA e
unsafeSubstitute = runSubstituter $ para \case
VarF name -> asks $ HM.findWithDefault (Var name) name
AbsF name e -> Abs name <$> maySubstitute (Identity name) e
ContFE e -> ContE <$> maySubstitute (Identity "!") e
CaseF pats -> Case <$> for pats \(Pat ctr ns e) -> Pat ctr ns <$> maySubstitute ns e
e -> embed <$> traverse snd e


@ -1,6 +1,6 @@
module LambdaCalculus.Evaluator.Continuation
( Continuation, continue, continue1
, ContinuationCrumb (ApplyTo, AppliedTo, AbstractedOver)
, ContinuationCrumb (..)
) where
import LambdaCalculus.Evaluator.Base


@ -1,46 +1,94 @@
module LambdaCalculus.Expression
( Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), DefF (..), VoidF, UnitF (..), Text
, substitute, substitute1, rename, free, bound, used
, Eval, EvalExpr, EvalX, EvalXF (..), Identity (..)
, pattern AppFE, pattern CtrE, pattern CtrFE,
pattern Cont, pattern ContF, pattern CallCC, pattern CallCCF
pattern ContE, pattern ContFE, pattern CallCCE, pattern CallCCFE
, Parse, AST, ASTF, ASTX, ASTXF (..), NonEmptyDefFs (..), NonEmpty (..), simplify
, pattern LetFP, pattern PNat, pattern PNatF, pattern PList, pattern PListF
, pattern PChar, pattern PCharF, pattern PStr, pattern PStrF
, ast2eval, eval2ast
, pattern PChar, pattern PCharF, pattern PStr, pattern PStrF, pattern HoleP, pattern HoleFP
, Check, CheckExpr, CheckExprF, CheckX, CheckXF (..)
, pattern AppFC, pattern CtrC, pattern CtrFC, pattern CallCCC, pattern CallCCFC
, pattern FixC, pattern FixFC, pattern HoleC, pattern HoleFC
, Type (..), TypeF (..), Scheme (..), tapp
, ast2check, ast2eval, check2eval, check2ast, eval2ast
) where
import LambdaCalculus.Evaluator.Base
import LambdaCalculus.Evaluator (alphaConvert, substitute)
import LambdaCalculus.Syntax.Base
import LambdaCalculus.Types.Base
import Data.Functor.Foldable (cata, hoist)
import Data.HashSet qualified as HS
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM
import Data.List (foldl')
import Data.List.NonEmpty (toList)
import Data.Text (unpack)
-- | Convert from an abstract syntax tree to an evaluator expression.
ast2eval :: AST -> EvalExpr
ast2eval = substitute "callcc" CallCC . cata \case
builtins :: HashMap Text CheckExpr
builtins = HM.fromList [("callcc", CallCCC), ("fix", FixC)]
-- | Convert from an abstract syntax tree to a typechecker expression.
ast2check :: AST -> CheckExpr
ast2check = substitute builtins . cata \case
VarF name -> Var name
AppF ef exs -> foldl' App ef $ toList exs
AbsF ns e -> foldr Abs e $ toList ns
LetF ds e ->
let letExpr name val body' = App (Abs name body') val
in foldr (uncurry letExpr) e $ getNonEmptyDefFs ds
CtrF ctr es -> foldl' App (CtrE ctr) es
CtrF ctr es -> foldl' App (CtrC ctr) es
CaseF ps -> Case ps
PNatF n -> int2ast n
PListF es -> mkList es
PStrF s -> mkList $ map (App (CtrE CChar) . int2ast . fromEnum) $ unpack s
PCharF c -> App (CtrE CChar) (int2ast $ fromEnum c)
PStrF s -> mkList $ map (App (CtrC CChar) . int2ast . fromEnum) $ unpack s
PCharF c -> App (CtrC CChar) (int2ast $ fromEnum c)
HoleFP -> HoleC
where
int2ast :: Int -> EvalExpr
int2ast 0 = CtrE CZero
int2ast n = App (CtrE CSucc) (int2ast (n - 1))
int2ast :: Int -> CheckExpr
int2ast 0 = CtrC CZero
int2ast n = App (CtrC CSucc) (int2ast (n - 1))
mkList :: [EvalExpr] -> EvalExpr
mkList = foldr (App . App (CtrE CCons)) (CtrE CNil)
mkList :: [CheckExpr] -> CheckExpr
mkList = foldr (App . App (CtrC CCons)) (CtrC CNil)
-- | Convert from a typechecker expression to an evaluator expression.
check2eval :: CheckExpr -> EvalExpr
check2eval = cata \case
VarF name -> Var name
AppFC ef ex -> App ef ex
AbsF n e -> Abs n e
LetF (Def nx ex) e -> App (Abs nx e) ex
CtrFC ctr -> CtrE ctr
CaseF ps -> Case ps
CallCCFC -> CallCCE
FixFC -> z
HoleFC -> omega
where
z, omega :: EvalExpr
z = App omega $ Abs "fix" $ Abs "f" $ Abs "x" $
App (App (Var "f") (App (App (Var "fix") (Var "fix")) (Var "f"))) (Var "x")
omega = Abs "x" (App (Var "x") (Var "x"))
-- | Convert from an abstract syntax tree to an evaluator expression.
ast2eval :: AST -> EvalExpr
ast2eval = check2eval . ast2check
-- | Convert from a typechecker expression to an abstract syntax tree.
check2ast :: CheckExpr -> AST
check2ast = hoist go . rename (HM.keysSet builtins)
where
go :: CheckExprF r -> ASTF r
go = \case
VarF name -> VarF name
AppFC ef ex -> AppF ef (ex :| [])
AbsF n e -> AbsF (n :| []) e
LetF (Def nx ex) e -> LetFP ((nx, ex) :| []) e
CtrFC ctr -> CtrF ctr []
CaseF ps -> CaseF ps
CallCCFC -> VarF "callcc"
FixFC -> VarF "fix"
HoleFC -> HoleFP
-- | Convert from an evaluator expression to an abstract syntax tree.
eval2ast :: EvalExpr -> AST
@ -48,14 +96,14 @@ eval2ast :: EvalExpr -> AST
-- all instances of `callcc` must be bound;
-- therefore, we are free to alpha convert them,
-- freeing the name `callcc` for us to use for the built-in again.
eval2ast = hoist go . alphaConvert (HS.singleton "callcc")
eval2ast = hoist go . rename (HM.keysSet builtins)
where
go :: EvalExprF r -> ASTF r
go = \case
VarF name -> VarF name
CallCCF -> VarF "callcc"
AppFE ef ex -> AppF ef (ex :| [])
AbsF n e -> AbsF (n :| []) e
CtrFE ctr -> CtrF ctr []
CaseF ps -> CaseF ps
ContF e -> AbsF ("!" :| []) e
CallCCFE -> VarF "callcc"
ContFE e -> AbsF ("!" :| []) e


@ -1,3 +1,4 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
module LambdaCalculus.Expression.Base
( Text, VoidF, UnitF (..), absurd'
@ -5,11 +6,24 @@ module LambdaCalculus.Expression.Base
, ExprF (..), PatF (..), DefF (..), AppArgsF, LetArgsF, CtrArgsF, XExprF
, RecursivePhase, projectAppArgs, projectLetArgs, projectCtrArgs, projectXExpr, projectDef
, embedAppArgs, embedLetArgs, embedCtrArgs, embedXExpr, embedDef
, Substitutable, free, bound, used, collectVars, rename, rename1
, substitute, substitute1, unsafeSubstitute, unsafeSubstitute1
, runRenamer, freshVar, replaceNames, runSubstituter, maySubstitute
) where
import Control.Monad.Reader (MonadReader, Reader, runReader, asks, local)
import Control.Monad.State (MonadState, StateT, evalStateT, state)
import Control.Monad.Zip (MonadZip, mzipWith)
import Data.Foldable (fold)
import Data.Functor.Foldable (Base, Recursive, Corecursive, project, embed)
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM
import Data.HashSet (HashSet)
import Data.HashSet qualified as HS
import Data.Kind (Type)
import Data.Stream qualified as S
import Data.Text (Text)
import Data.Text qualified as T
data Expr phase
-- | A variable: `x`.
@ -100,8 +114,8 @@ type family LetArgsF phase :: Type -> Type
type family CtrArgsF phase :: Type -> Type
type family XExprF phase :: Type -> Type
data DefF r = DefF !Text !r
deriving (Eq, Functor, Show)
data DefF r = Def !Text !r
deriving (Eq, Functor, Foldable, Traversable, Show)
-- | A contractible data type with one extra type parameter.
data UnitF a = Unit
@ -199,10 +213,10 @@ class Functor (ExprF phase) => RecursivePhase phase where
embedXExpr = id
projectDef :: Def phase -> DefF (Expr phase)
projectDef = uncurry DefF
projectDef = uncurry Def
embedDef :: DefF (Expr phase) -> Def phase
embedDef (DefF n e) = (n, e)
embedDef (Def n e) = (n, e)
instance RecursivePhase phase => Recursive (Expr phase) where
project = \case
@ -227,3 +241,117 @@ instance RecursivePhase phase => Corecursive (Expr phase) where
---
--- End base functor boilerplate.
---
class Substitutable e where
-- | Fold over the variables in the expression with a monoid,
-- given what to do with variable usage sites and binding sites respectively.
collectVars :: Monoid m => (Text -> m) -> (Text -> m -> m) -> e -> m
-- | Free variables are variables which occur anywhere in an expression
-- where they are not bound by an abstraction.
free :: e -> HashSet Text
free = collectVars HS.singleton HS.delete
-- | Bound variables are variables which are abstracted over anywhere in an expression.
bound :: e -> HashSet Text
bound = collectVars (const HS.empty) HS.insert
-- | Used variables are variables which appear *anywhere* in an expression, free or bound.
used :: e -> HashSet Text
used = collectVars HS.singleton HS.insert
-- | Given a map between variable names and expressions,
-- replace each free occurrence of a variable with its respective expression.
substitute :: HashMap Text e -> e -> e
substitute substs = unsafeSubstitute substs . rename (foldMap free substs)
substitute1 :: Text -> e -> e -> e
substitute1 n e = substitute (HM.singleton n e)
-- | Rename all bound variables in an expression (both at binding sites and usage sites)
-- with new names where the new names are *not* members of the provided set.
rename :: HashSet Text -> e -> e
rename1 :: Text -> e -> e
rename1 n = rename (HS.singleton n)
-- | A variant of substitution which does *not* avoid variable capture;
-- it only gives the correct result if the bound variables in the body
-- are disjoint from the free variables in the argument.
unsafeSubstitute :: HashMap Text e -> e -> e
unsafeSubstitute1 :: Text -> e -> e -> e
unsafeSubstitute1 n e = unsafeSubstitute (HM.singleton n e)
--
-- These primitives are likely to be useful for implementing `rename`.
-- Ideally, I would like to find a way to move the implementation of `rename` here entirely,
-- but I haven't yet figured out an appropriate abstraction to do so.
--
-- | Run an action which requires a stateful context of used variable names
-- and a local context of variable replacements.
--
-- This is a useful monad for implementing the `rename` function.
runRenamer :: Substitutable e
=> (HashSet Text -> e -> StateT (HashSet Text) (Reader (HashMap Text Text)) a)
-> HashSet Text
-> e
-> a
runRenamer m ctx e = runReader (evalStateT (m ctx e) dirtyNames) HM.empty
where dirtyNames = HS.union ctx (used e)
-- | Create a new variable name within a context of used variable names.
freshVar :: MonadState (HashSet Text) m => Text -> m Text
freshVar baseName =
state \ctx -> let name = newName ctx in (name, HS.insert name ctx)
where
names = S.iterate (`T.snoc` '\'') baseName
newName ctx = S.head $ S.filter (not . (`HS.member` ctx)) names
-- | Replace a collection of old variable names with new variable names
-- and apply those replacements within a context.
replaceNames :: ( MonadReader (HashMap Text Text) m
, MonadState (HashSet Text) m
, MonadZip t, Traversable t
)
=> HashSet Text -> t Text -> m a -> m (t Text, a)
replaceNames badNames names m = do
newNames <- mapM freshVarIfNecessary names
let replacements = HM.filterWithKey (/=) $ fold $ mzipWith HM.singleton names newNames
x <- local (HM.union replacements) m
pure (newNames, x)
where
freshVarIfNecessary name
| name `HS.member` badNames = freshVar name
| otherwise = pure name
---
--- The same as the above section but for `substitute`.
--- This is useful when implementing substitution as a paramorphism.
---
-- | Run an action in a local context of substitutions.
--
-- This monad is useful for implementing, you guessed it, substitution.
runSubstituter :: (e -> Reader (HashMap Text e) a)
-> HashMap Text e
-> e
-> a
runSubstituter m substs e = runReader (m e) substs
-- | Apply only the substitutions which are not bound,
-- and only if there are substitutions left to apply.
maySubstitute :: ( MonadReader (HashMap Text b) m
, Functor t, Foldable t
)
=> t Text -> (a, m a) -> m a
maySubstitute ns (unmodified, substituted) =
local (compose $ fmap HM.delete ns) do
noMoreSubsts <- asks HM.null
if noMoreSubsts
then pure unmodified
else substituted
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id
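
As a point of comparison for the `rename`-then-`unsafeSubstitute` design above, here is a minimal sketch of textbook capture-avoiding substitution on a plain untyped term type. It is not the code from this commit: `Term`, `freeVars`, `fresh` and `subst` are hypothetical names, and it renames a binder on demand rather than in a separate pass. Fresh names are made by appending primes, as `freshVar` does.

```haskell
import Data.List (union)

-- Hypothetical untyped lambda terms, far smaller than the project's `Expr`.
data Term = Var String | App Term Term | Abs String Term
  deriving (Eq, Show)

-- Free variables of a term.
freeVars :: Term -> [String]
freeVars (Var n)   = [n]
freeVars (App f x) = freeVars f `union` freeVars x
freeVars (Abs n e) = filter (/= n) (freeVars e)

-- Pick a name outside `avoid` by appending primes to the base name.
fresh :: [String] -> String -> String
fresh avoid n = head (filter (`notElem` avoid) (iterate (++ "'") n))

-- Replace free occurrences of `var` with `val`, renaming any binder
-- that would otherwise capture a free variable of `val`.
subst :: String -> Term -> Term -> Term
subst var val = go
  where
    fv = freeVars val
    go (Var n)
      | n == var  = val
      | otherwise = Var n
    go (App f x) = App (go f) (go x)
    go (Abs n e)
      | n == var    = Abs n e  -- `var` is shadowed here; nothing to replace
      | n `elem` fv =
          let n' = fresh (var : (fv `union` freeVars e)) n
          in Abs n' (go (subst n (Var n') e))
      | otherwise   = Abs n (go e)
```

Substituting `y` for `x` in `\y z. x y` with this sketch yields `\y' z. y y'`, the same shape the README shows for `(\x y z. x y) y`.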


@ -1,8 +1,9 @@
module LambdaCalculus.Syntax.Base
( Expr (..), ExprF (..), Ctr (..), Pat, Def, DefF (..), PatF (..), VoidF, Text, NonEmpty (..)
, substitute, substitute1, rename, rename1, free, bound, used
, Parse, AST, ASTF, ASTX, ASTXF (..), NonEmptyDefFs (..)
, pattern LetFP, pattern PNat, pattern PNatF, pattern PList, pattern PListF
, pattern PChar, pattern PCharF, pattern PStr, pattern PStrF
, pattern PChar, pattern PCharF, pattern PStr, pattern PStrF, pattern HoleP, pattern HoleFP
, simplify
) where
@ -53,6 +54,8 @@ data ASTXF r
| PChar_ Char
-- | A string literal, e.g. `"abcd"`.
| PStr_ Text
-- | A type hole.
| HoleP_
deriving (Eq, Functor, Foldable, Traversable, Show)
instance RecursivePhase Parse where
@ -62,7 +65,6 @@ instance RecursivePhase Parse where
newtype NonEmptyDefFs r = NonEmptyDefFs { getNonEmptyDefFs :: NonEmpty (Text, r) }
deriving (Eq, Functor, Foldable, Traversable, Show)
pattern LetFP :: NonEmpty (Text, r) -> r -> ASTF r
pattern LetFP ds e = LetF (NonEmptyDefFs ds) e
@ -90,10 +92,18 @@ pattern PStrF s = ExprXF (PStr_ s)
pattern PStr :: Text -> AST
pattern PStr s = ExprX (PStr_ s)
{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, ExprXF #-}
{-# COMPLETE Var, App, Abs, Let, Ctr, Case, PNat, PList, PChar, PStr #-}
{-# COMPLETE VarF, AppF, AbsF, LetF , CtrF, CaseF, PNatF, PListF, PCharF, PStrF #-}
{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, PNatF, PListF, PCharF, PStrF #-}
pattern HoleP :: AST
pattern HoleP = ExprX HoleP_
pattern HoleFP :: ASTF r
pattern HoleFP = ExprXF HoleP_
{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, ExprXF #-}
{-# COMPLETE Var, App, Abs, Let, Ctr, Case, PNat, PList, PChar, PStr, HoleP #-}
{-# COMPLETE VarF, AppF, AbsF, LetF , CtrF, CaseF, PNatF, PListF, PCharF, PStrF, HoleFP #-}
{-# COMPLETE VarF, AppF, AbsF, LetFP, CtrF, CaseF, PNatF, PListF, PCharF, PStrF, HoleFP #-}
-- TODO: Implement Substitutable for AST.
-- | Combine nested expressions into compound expressions or literals when possible.
simplify :: AST -> AST


@ -86,10 +86,10 @@ ctr = pair <|> unit <|> either <|> nat <|> list <|> str
succ = Ctr CSucc [] <$ keyword "S"
natLit = (PNat . read <$> many1 digit) <* spaces
list = cons <|> consCtr <|> listLit
consCtr = Ctr CCons [] <$ keyword "(:)"
consCtr = Ctr CCons [] <$ keyword "(::)"
cons = try $ between (token '(') (token ')') do
e1 <- ambiguous
token ':'
keyword "::"
e2 <- ambiguous
pure $ Ctr CCons [e1, e2]
listLit = fmap PList $ between (token '[') (token ']') $ sepEndBy ambiguous (token ',')
@ -105,35 +105,36 @@ pat = label "case alternate" $ do
keyword "->"
e <- ambiguous
pure $ Pat c ns e
where pair = try $ between (token '(') (token ')') do
e1 <- identifier
token ','
e2 <- identifier
pure (CPair, [e1, e2])
unit = (CUnit, []) <$ keyword "()"
left = do
keyword "Left"
e <- identifier
pure (CLeft, [e])
right = do
keyword "Right"
e <- identifier
pure (CRight, [e])
zero = (CZero, []) <$ keyword "Z"
succ = do
keyword "S"
e <- identifier
pure (CSucc, [e])
nil = (CNil, []) <$ keyword "[]"
cons = try $ between (token '(') (token ')') do
e1 <- identifier
token ':'
e2 <- identifier
pure (CCons, [e1, e2])
char' = do
keyword "Char"
e <- identifier
pure (CChar, [e])
where
pair = try $ between (token '(') (token ')') do
e1 <- identifier
token ','
e2 <- identifier
pure (CPair, [e1, e2])
unit = (CUnit, []) <$ keyword "()"
left = do
keyword "Left"
e <- identifier
pure (CLeft, [e])
right = do
keyword "Right"
e <- identifier
pure (CRight, [e])
zero = (CZero, []) <$ keyword "Z"
succ = do
keyword "S"
e <- identifier
pure (CSucc, [e])
nil = (CNil, []) <$ keyword "[]"
cons = try $ between (token '(') (token ')') do
e1 <- identifier
keyword "::"
e2 <- identifier
pure (CCons, [e1, e2])
char' = do
keyword "Char"
e <- identifier
pure (CChar, [e])
case_ :: Parser AST
case_ = label "case patterns" $ do
@ -142,13 +143,16 @@ case_ = label "case patterns" $ do
token '}'
pure $ Case pats
hole :: Parser AST
hole = label "hole" $ HoleP <$ token '_'
-- | Guaranteed to consume a finite amount of input
finite :: Parser AST
finite = label "finite expression" $ variable <|> ctr <|> case_ <|> grouping
finite = label "finite expression" $ variable <|> hole <|> ctr <|> case_ <|> grouping
-- | Guaranteed to consume input, but may continue until it reaches a terminator
block :: Parser AST
block = label "block expression" $ finite <|> abstraction <|> let_
block = label "block expression" $ abstraction <|> let_ <|> finite
-- | Not guaranteed to consume input at all, may continue until it reaches a terminator
ambiguous :: Parser AST


@ -71,6 +71,7 @@ unparseAST = toStrict . toLazyText . snd . cata \case
in tag Finite $ "[" <> es' <> "]"
PStrF s -> tag Finite $ "\"" <> fromText s <> "\""
PCharF c -> tag Finite $ "'" <> fromLazyText (singleton c)
HoleFP -> tag Finite "_"
where
unparseApp :: Tagged Builder -> NonEmpty (Tagged Builder) -> Tagged Builder
unparseApp ef (unsnoc -> (exs, efinal))
@ -78,8 +79,8 @@ unparseAST = toStrict . toLazyText . snd . cata \case
unparseCtr :: Ctr -> [Tagged Builder] -> Tagged Builder
-- Fully-applied special syntax forms
unparseCtr CPair [x, y] = tag Finite $ "(" <> unambiguous x <> ", " <> unambiguous y <> ")"
unparseCtr CCons [x, y] = tag Finite $ "(" <> unambiguous x <> " : " <> unambiguous y <> ")"
unparseCtr CPair [x, y] = tag Finite $ "(" <> unambiguous x <> ", " <> unambiguous y <> ")"
unparseCtr CCons [x, y] = tag Finite $ "(" <> unambiguous x <> " :: " <> unambiguous y <> ")"
-- Partially-applied syntax forms
unparseCtr CUnit [] = tag Finite "()"
unparseCtr CPair [] = tag Finite "(,)"
@ -88,7 +89,7 @@ unparseAST = toStrict . toLazyText . snd . cata \case
unparseCtr CZero [] = tag Finite "Z"
unparseCtr CSucc [] = tag Finite "S"
unparseCtr CNil [] = tag Finite "[]"
unparseCtr CCons [] = tag Finite "(:)"
unparseCtr CCons [] = tag Finite "(::)"
unparseCtr CChar [] = tag Finite "Char"
unparseCtr ctr (x:xs) = unparseApp (unparseCtr ctr []) (x :| xs)

src/LambdaCalculus/Types.hs Normal file

@ -0,0 +1,158 @@
module LambdaCalculus.Types
  ( module LambdaCalculus.Types.Base
  , infer
  ) where

import LambdaCalculus.Types.Base

import Control.Applicative ((<|>))
import Control.Monad (when)
import Control.Monad.Except (MonadError, throwError)
import Control.Monad.Reader (MonadReader, runReader, asks, local)
import Control.Monad.RWS
  ( RWST, evalRWST
  , MonadState, state
  , MonadWriter, tell, listen
  )
import Data.Foldable (forM_, toList)
import Data.HashSet qualified as HS
import Data.HashMap.Strict qualified as HM
import Data.Stream (Stream (..), fromList)
import Data.Text qualified as T

fresh :: MonadState (Stream Text) m => m Type
fresh = state \(Cons n ns) -> (TVar n, ns)

inst :: MonadState (Stream Text) m => Scheme -> m Type
inst (TForall ns t) = foldr (\n_n t' -> substitute1 n_n <$> fresh <*> t') (pure t) ns

lookupVar :: (MonadReader Context m, MonadState (Stream Text) m, MonadError Text m) => Text -> m Type
lookupVar n = do
  t_polyq <- asks (HM.!? n)
  case t_polyq of
    Nothing -> throwError $ "Variable not bound: " <> n
    Just t_poly -> inst t_poly

generalize :: MonadReader Context m => Type -> m Scheme
generalize t = do
  ctx <- asks HM.keysSet
  pure $ TForall (toList $ HS.difference (free t) ctx) t

bindVar :: MonadReader Context m => Text -> Type -> m a -> m a
bindVar n t = local (HM.insert n (TForall [] t))

unify :: MonadWriter [Constraint] m => Type -> Type -> m ()
unify t1 t2 = tell [(t1, t2)]

ctrTy :: MonadState (Stream Text) m => Ctr -> m (Type, [Type])
ctrTy = \case
    CUnit -> pure (TUnit, [])
    CZero -> pure (TNat, [])
    CSucc -> pure (TNat, [TNat])
    CChar -> pure (TChar, [TNat])
    CNil -> mkUnary TList $ const []
    CCons -> mkUnary TList \t_a -> [t_a, TApp TList t_a]
    CPair -> mkBinary TProd \t_a t_b -> [t_a, t_b]
    CLeft -> mkBinary TSum \t_a _ -> [t_a]
    CRight -> mkBinary TSum \_ t_b -> [t_b]
  where
    mkBinary tc tcas = do
      t_a <- fresh
      t_b <- fresh
      pure (tapp [tc, t_a, t_b], tcas t_a t_b)

    mkUnary tc tcas = do
      t_a <- fresh
      pure (TApp tc t_a, tcas t_a)

j :: (MonadError Text m, MonadReader Context m, MonadState (Stream Text) m, MonadWriter [Constraint] m)
  => CheckExpr -> m Type
j (Var name) = lookupVar name
j (App e_fun e_arg) = do
  t_ret <- fresh
  t_fun <- j e_fun
  t_arg <- j e_arg
  unify t_fun (tapp [TAbs, t_arg, t_ret])
  pure t_ret
j (Abs n_arg e_ret) = do
  t_arg <- fresh
  t_ret <- bindVar n_arg t_arg $ j e_ret
  pure $ tapp [TAbs, t_arg, t_ret]
j (Let (n_x, e_x) e_ret) = do
  (t_x_mono, c) <- listen $ j e_x
  s <- solve' c
  t_x_poly <- generalize $ substitute s t_x_mono
  local (HM.insert n_x t_x_poly) $ j e_ret
-- In a case expression:
-- * the pattern for each branch has the same type as the expression being matched, and
-- * the return type for each branch has the same type as the return type of the case expression as a whole.
j (Case ctrs) = do
  t_ret <- fresh
  t_x <- fresh
  forM_ ctrs \(Pat ctr ns_n e) -> do
    (t_x', ts_n) <- ctrTy ctr
    unify t_x t_x'
    when (length ts_n /= length ns_n) $ throwError "Constructor arity mismatch"
    t_ret' <- local (HM.union $ HM.fromList $ zip ns_n $ map (TForall []) ts_n) $ j e
    unify t_ret t_ret'
  pure $ tapp [TAbs, t_x, t_ret]
j (CtrC ctr) = do
  (t_ret, ts_n) <- ctrTy ctr
  pure $ foldr (\t_a t_r -> tapp [TAbs, t_a, t_r]) t_ret ts_n
j CallCCC = do
  t_a <- fresh
  t_b <- fresh
  pure $ tapp [TAbs, tapp [TAbs, tapp [TAbs, t_a, t_b], t_a], t_a]
j FixC = do
  t_a <- fresh
  pure $ tapp [TAbs, tapp [TAbs, t_a, t_a], t_a]
j HoleC = asks show >>= throwError . (<>) "Encountered hole with context: " . T.pack

occurs :: Text -> Type -> Bool
occurs n t = HS.member n (free t)

findDifference :: MonadError (Type, Type) m => Type -> Type -> m (Maybe (Text, Type))
findDifference t1 t2
  | t1 == t2 = pure Nothing
  | TVar n1 <- t1, not (occurs n1 t2) = pure $ Just (n1, t2)
  | TVar _ <- t2 = findDifference t2 t1
  | TApp a1 b1 <- t1, TApp a2 b2 <- t2 = (<|>) <$> findDifference a1 a2 <*> findDifference b1 b2
  | otherwise = throwError (t1, t2)

unifies :: MonadError (Type, Type) m => Type -> Type -> m Substitution
unifies t1 t2 = do
  dq <- findDifference t1 t2
  case dq of
    Nothing -> pure HM.empty
    Just s -> do
      ss <- unifies (uncurry substitute1 s t1) (uncurry substitute1 s t2)
      pure $ uncurry HM.insert (fmap (substitute ss) s) ss

solve :: MonadError (Type, Type) m => [Constraint] -> m Substitution
solve [] = pure HM.empty
solve (c:cs) = do
  s <- uncurry unifies c
  ss <- solve (substituteMono s cs)
  pure $ HM.union ss (substituteMono ss s)

solve' :: MonadError Text m => [Constraint] -> m Substitution
solve' c = case solve c of
  Right ss -> pure ss
  Left (t1, t2) -> throwError $ "Could not unify " <> unparseType t1 <> " with " <> unparseType t2

type Inferencer a = RWST Context [Constraint] (Stream Text) (Either Text) a

runInferencer :: Inferencer a -> Either Text (a, [Constraint])
runInferencer m = evalRWST m HM.empty freshNames
  where
    freshNames = fromList $ do
      n <- [0 :: Int ..]
      c <- ['a'..'z']
      pure $ T.pack if n == 0 then [c] else c : show n

infer :: CheckExpr -> Either Text Scheme
infer e = do
  (t, c) <- runInferencer $ j e
  s <- solve' c
  let t' = substitute s t
  pure $ runReader (generalize t') HM.empty
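
The unification step (`findDifference`, `unifies`, `occurs`) can be illustrated with a smaller, self-contained sketch. It is not the code from this commit. It uses the more common structural formulation rather than the find-one-difference loop, association lists instead of `HashMap`, and hypothetical names (`Ty`, `Subst`, `apply`, `unifyTy`, `fn`).

```haskell
import Data.Maybe (fromMaybe)

-- Assumption: a cut-down type language (variables, constants, application).
data Ty = TV String | TCon String | TAp Ty Ty
  deriving (Eq, Show)

type Subst = [(String, Ty)]

-- Apply a substitution to a type.
apply :: Subst -> Ty -> Ty
apply s (TV n) = fromMaybe (TV n) (lookup n s)
apply s (TAp a b) = TAp (apply s a) (apply s b)
apply _ t = t

-- Occurs check: does the variable appear anywhere in the type?
occursIn :: String -> Ty -> Bool
occursIn n (TV m) = n == m
occursIn n (TAp a b) = occursIn n a || occursIn n b
occursIn _ _ = False

-- Unify two types, reporting the mismatched pair on failure.
unifyTy :: Ty -> Ty -> Either (Ty, Ty) Subst
unifyTy t1 t2
  | t1 == t2 = Right []
unifyTy (TV n) t
  | not (occursIn n t) = Right [(n, t)]
unifyTy t (TV n)
  | not (occursIn n t) = Right [(n, t)]
unifyTy (TAp a1 b1) (TAp a2 b2) = do
  s1 <- unifyTy a1 a2
  s2 <- unifyTy (apply s1 b1) (apply s1 b2)
  -- Compose: later bindings are pushed into the earlier ones.
  pure (s2 ++ [(n, apply s2 t) | (n, t) <- s1])
unifyTy t1 t2 = Left (t1, t2)

-- Function types encoded as applications of a constant, as `tapp [TAbs, a, b]` does.
fn :: Ty -> Ty -> Ty
fn a b = TAp (TAp (TCon "->") a) b
```

As in the diff, a failed occurs check falls through to the error case, so unifying `a` with `a -> a` is rejected rather than producing an infinite type.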


@ -0,0 +1,242 @@
{-# LANGUAGE TemplateHaskell #-}
module LambdaCalculus.Types.Base
  ( Identity (..)
  , Expr (..), Ctr (..), Pat, ExprF (..), PatF (..), VoidF, UnitF (..), Text
  , substitute, substitute1, rename, rename1, free, bound, used
  , Check, CheckExpr, CheckExprF, CheckX, CheckXF (..)
  , pattern AppFC, pattern CtrC, pattern CtrFC, pattern CallCCC, pattern CallCCFC
  , pattern FixC, pattern FixFC, pattern HoleC, pattern HoleFC
  , Type (..), TypeF (..), Scheme (..), tapp
  , Substitution, Context, Constraint
  , MonoSubstitutable, substituteMono, substituteMono1
  , unparseType, unparseScheme
  ) where

import Control.Monad (forM)
import Control.Monad.Reader (asks)
import Data.Bifunctor (bimap, first)
import Data.Foldable (fold)
import Data.Functor.Foldable (embed, cata, para)
import Data.Functor.Foldable.TH (makeBaseFunctor)
import Data.Functor.Identity (Identity (..))
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM
import Data.List (foldl1')
import Data.Text qualified as T
import Data.Traversable (for)
import LambdaCalculus.Expression.Base
data Check
type CheckExpr = Expr Check
type instance AppArgs Check = CheckExpr
type instance AbsArgs Check = Text
type instance LetArgs Check = (Text, CheckExpr)
type instance CtrArgs Check = UnitF CheckExpr
type instance XExpr Check = CheckX
type CheckX = CheckXF CheckExpr
type CheckExprF = ExprF Check
type instance AppArgsF Check = Identity
type instance LetArgsF Check = DefF
type instance CtrArgsF Check = UnitF
type instance XExprF Check = CheckXF
data CheckXF r
  -- | Call-with-current-continuation.
  = CallCCC_
  -- | A fixpoint combinator,
  -- because untyped lambda calculus fixpoint combinators won't typecheck.
  | FixC_
  -- | A hole to ask the type inferencer about the context for debugging purposes.
  | HoleC_
  deriving (Eq, Functor, Foldable, Traversable, Show)
pattern CtrC :: Ctr -> CheckExpr
pattern CtrC c = Ctr c Unit
pattern CtrFC :: Ctr -> CheckExprF r
pattern CtrFC c = CtrF c Unit
pattern AppFC :: r -> r -> CheckExprF r
pattern AppFC ef ex = AppF ef (Identity ex)
pattern CallCCC :: CheckExpr
pattern CallCCC = ExprX CallCCC_
pattern CallCCFC :: CheckExprF r
pattern CallCCFC = ExprXF CallCCC_
pattern FixC :: CheckExpr
pattern FixC = ExprX FixC_
pattern FixFC :: CheckExprF r
pattern FixFC = ExprXF FixC_
pattern HoleC :: CheckExpr
pattern HoleC = ExprX HoleC_
pattern HoleFC :: CheckExprF r
pattern HoleFC = ExprXF HoleC_
{-# COMPLETE Var, App, Abs, Let, CtrC, Case, ExprX #-}
{-# COMPLETE VarF, AppF, AbsF, LetF, CtrFC, CaseF, ExprXF #-}
{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrF, CaseF, ExprXF #-}
{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrFC, CaseF, ExprXF #-}
{-# COMPLETE Var, App, Abs, Let, Ctr, Case, CallCCC, FixC, HoleC #-}
{-# COMPLETE Var, App, Abs, Let, CtrC, Case, CallCCC, FixC, HoleC #-}
{-# COMPLETE VarF, AppF, AbsF, LetF, CtrFC, CaseF, CallCCFC, FixFC, HoleFC #-}
{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrF, CaseF, CallCCFC, FixFC, HoleFC #-}
{-# COMPLETE VarF, AppFC, AbsF, LetF, CtrFC, CaseF, CallCCFC, FixFC, HoleFC #-}
instance RecursivePhase Check where
  projectAppArgs = Identity
  projectLetArgs = projectDef

  embedAppArgs = runIdentity
  embedLetArgs = embedDef

instance Substitutable CheckExpr where
  collectVars withVar withBinder = cata \case
    VarF n -> withVar n
    AbsF n e -> withBinder n e
    LetF (Def n x) e -> x <> withBinder n e
    CaseF pats -> foldMap (\(Pat _ ns e) -> foldr withBinder e ns) pats
    e -> fold e

  rename = runRenamer $ \badNames -> cata \case
    VarF n -> asks $ Var . HM.findWithDefault n n
    AbsF n e -> uncurry Abs . first runIdentity <$> replaceNames badNames (Identity n) e
    LetF (Def n x) e -> do
      x' <- x
      (Identity n', e') <- replaceNames badNames (Identity n) e
      pure $ Let (n', x') e'
    CaseF ps -> Case <$> forM ps \(Pat ctr ns e) ->
      uncurry (Pat ctr) <$> replaceNames badNames ns e
    e -> embed <$> sequenceA e

  unsafeSubstitute = runSubstituter $ para \case
    VarF name -> asks $ HM.findWithDefault (Var name) name
    AbsF name e -> Abs name <$> maySubstitute (Identity name) e
    LetF (Def name (_, x)) e -> do
      x' <- x
      e' <- maySubstitute (Identity name) e
      pure $ Let (name, x') e'
    CaseF pats -> Case <$> for pats \(Pat ctr ns e) -> Pat ctr ns <$> maySubstitute ns e
    e -> embed <$> traverse snd e

-- | A monomorphic type.
data Type
  -- | Type variable.
  = TVar Text
  -- | Type application.
  | TApp Type Type
  -- | The function type.
  | TAbs
  -- | The product type.
  | TProd
  -- | The sum type.
  | TSum
  -- | The unit type.
  | TUnit
  -- | The empty type.
  | TVoid
  -- | The type of natural numbers.
  | TNat
  -- | The type of lists.
  | TList
  -- | The type of characters.
  | TChar
  deriving (Eq, Show)
makeBaseFunctor ''Type
instance Substitutable Type where
  collectVars withVar _ = cata \case
    TVarF n -> withVar n
    t -> fold t

  -- /All/ variables in a monomorphic type are free.
  rename _ t = t

  -- No renaming step is necessary.
  substitute substs = cata \case
    TVarF n -> HM.findWithDefault (TVar n) n substs
    e -> embed e

  unsafeSubstitute = substitute

-- | A polymorphic type.
data Scheme
  -- | Universally quantified type variables.
  = TForall [Text] Type
  deriving (Eq, Show)

instance Substitutable Scheme where
  collectVars withVar withBinder (TForall names t) =
    foldMap withBinder names $ collectVars withVar withBinder t

  rename = runRenamer \badNames (TForall names t) ->
    uncurry TForall <$> replaceNames badNames names (pure t)

  -- I took a shot at implementing this but found it to be quite difficult
  -- because merging the foralls is tricky.
  -- It's not impossible, but it wasn't worth investing more time
  -- seeing as this function isn't currently used anywhere.
  unsafeSubstitute = error "Substitution for schemes not yet implemented"
type Substitution = HashMap Text Type
type Context = HashMap Text Scheme
type Constraint = (Type, Type)
class MonoSubstitutable t where
  substituteMono :: Substitution -> t -> t

  substituteMono1 :: Text -> Type -> t -> t
  substituteMono1 var val = substituteMono (HM.singleton var val)

instance MonoSubstitutable Type where
  substituteMono = substitute

instance MonoSubstitutable Scheme where
  substituteMono substs (TForall names t) =
    TForall names $ substitute (foldMap HM.delete names substs) t

instance MonoSubstitutable Constraint where
  substituteMono substs = bimap (substituteMono substs) (substituteMono substs)

instance MonoSubstitutable t => MonoSubstitutable [t] where
  substituteMono = fmap . substituteMono

instance MonoSubstitutable t => MonoSubstitutable (HashMap a t) where
  substituteMono = fmap . substituteMono

tapp :: [Type] -> Type
tapp [] = error "Empty type applications are not permitted"
tapp [t] = t
tapp ts = foldl1' TApp ts
-- HACK
pattern TApp2 :: Type -> Type -> Type -> Type
pattern TApp2 tf tx ty = TApp (TApp tf tx) ty
-- TODO: Improve these printers.
unparseType :: Type -> Text
unparseType (TVar name) = name
unparseType (TApp2 TAbs a b) = "(" <> unparseType a <> " -> " <> unparseType b <> ")"
unparseType (TApp2 TProd a b) = "(" <> unparseType a <> " * " <> unparseType b <> ")"
unparseType (TApp2 TSum a b) = "(" <> unparseType a <> " + " <> unparseType b <> ")"
unparseType (TApp TList a) = "[" <> unparseType a <> "]"
unparseType (TApp a b) = "(" <> unparseType a <> " " <> unparseType b <> ")"
unparseType TAbs = "(->)"
unparseType TProd = "(*)"
unparseType TSum = "(+)"
unparseType TUnit = ""
unparseType TVoid = ""
unparseType TNat = "Nat"
unparseType TList = "[]"
unparseType TChar = "Char"
unparseScheme :: Scheme -> Text
unparseScheme (TForall [] t) = unparseType t
unparseScheme (TForall names t) = "∀" <> T.unwords names <> ". " <> unparseType t
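
To make the `tapp` encoding concrete, here is a minimal sketch of how a left-nested chain of applications represents a binary type constructor, and how a printer can recognise it again (the role `TApp2` plays above). The names `Ty`, `TFun`, `tappTy`, and `render` are hypothetical and cover only a fragment of the real `Type`.

```haskell
import Data.List (foldl1')

-- Assumption: only variables, the function constructor, and application.
data Ty = TV String | TFun | TAp Ty Ty
  deriving (Eq, Show)

-- Left-nested application; like the original, partial on the empty list.
tappTy :: [Ty] -> Ty
tappTy = foldl1' TAp

-- Print a saturated function type infix, anything else as plain application.
render :: Ty -> String
render (TV n) = n
render (TAp (TAp TFun a) b) = "(" ++ render a ++ " -> " ++ render b ++ ")"
render (TAp a b) = "(" ++ render a ++ " " ++ render b ++ ")"
render TFun = "(->)"
```

For example, `tappTy [TFun, TV "a", TV "b"]` builds `TAp (TAp TFun (TV "a")) (TV "b")`, which is why the saturated case must be matched before the generic `TAp` case.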


@ -39,10 +39,10 @@ omega = App x x
where x = Abs "x" (App (Var "x") (Var "x"))
cc1 :: EvalExpr
cc1 = App CallCC (Abs "k" (App omega (App (Var "k") (Var "z"))))
cc1 = App CallCCE (Abs "k" (App omega (App (Var "k") (Var "z"))))
cc2 :: EvalExpr
cc2 = App (Var "y") (App CallCC (Abs "k" (App (Var "z") (App (Var "k") (Var "x")))))
cc2 = App (Var "y") (App CallCCE (Abs "k" (App (Var "z") (App (Var "k") (Var "x")))))
main :: IO ()
main = defaultMain $