Skip to content

Commit

Permalink
Improve type inference of functions (#9)
Browse files Browse the repository at this point in the history
* Improve type inference of functions

* Handle type annotations on functions

* Handle return statements and support if-only statements

* handle if-else chains
  • Loading branch information
kevinbarabash committed Nov 8, 2023
1 parent c351278 commit 22b4a64
Show file tree
Hide file tree
Showing 9 changed files with 664 additions and 97 deletions.
21 changes: 11 additions & 10 deletions src/Escalier.Data/Escalier.Data.fsproj
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net7.0</TargetFramework>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<Compile Include="Library.fs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="FParsec" Version="1.1.1" />
</ItemGroup>
<PropertyGroup>
<TargetFramework>net7.0</TargetFramework>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<Compile Include="Library.fs"/>
<Compile Include="Visitor.fs"/>
</ItemGroup>
<ItemGroup>
<PackageReference Include="FParsec" Version="1.1.1"/>
</ItemGroup>
</Project>
45 changes: 33 additions & 12 deletions src/Escalier.Data/Library.fs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ module Syntax =
| For of leff: Pattern * right: Expr * body: Block
| Return of option<Expr>
| Decl of Decl
| Assign

type Stmt = { span: Span; kind: StmtKind }

Expand All @@ -39,7 +38,8 @@ module Syntax =
match this with
| Number(value) -> value
| String(value) -> $"\"{value}\""
| Boolean(value) -> value |> string
| Boolean(true) -> "true"
| Boolean(false) -> "false"
| Null -> "null"
| Undefined -> "undefined"

Expand Down Expand Up @@ -67,6 +67,7 @@ module Syntax =
| Tuple of elems: list<Pattern>
| Wildcard
| Literal of span: Span * value: Literal
// TODO: get rid of `is_mut` since it's covered by `ident: BindingIdent`
| Is of span: Span * ident: BindingIdent * is_name: string * is_mut: bool

type Pattern =
Expand Down Expand Up @@ -125,6 +126,18 @@ module Syntax =
{ parts: list<string>
exprs: list<Expr> }

type FuncParam<'T> =
{ pattern: Pattern
typeAnn: 'T
optional: bool }

type Function =
{ param_list: list<FuncParam<option<TypeAnn>>>
return_type: option<TypeAnn>
type_params: option<list<TypeParam>>
throws: option<TypeAnn>
body: BlockOrExpr }

type ExprKind =
| Identifer of string
| Literal of Literal
Expand All @@ -133,7 +146,7 @@ module Syntax =
| Assign of left: Expr * op: AssignOp * right: Expr
| Binary of left: Expr * op: BinaryOp * right: Expr
| Unary of op: string * value: Expr
| Function of param_list: list<string> * body: BlockOrExpr
| Function of Function
| Call of
callee: Expr *
type_args: option<list<TypeAnn>> *
Expand All @@ -142,11 +155,14 @@ module Syntax =
throws: option<Type.Type>
| Index of target: Expr * index: Expr * opt_chain: bool
| Member of target: Expr * name: string * opt_chain: bool
| If of cond: Expr * then_branch: BlockOrExpr * else_branch: BlockOrExpr
| IfElse of
cond: Expr *
then_branch: BlockOrExpr *
else_branch: option<BlockOrExpr>
| Match of target: Expr * cases: list<MatchCase>
| Try of body: Block * catch: option<Expr> * finally_: option<Expr>
| Do of body: Block
| Await of value: Expr
| Await of value: Expr // TODO: convert rejects to throws
| Throw of value: Expr
| TemplateLiteral of TemplateLiteral
| TaggedTemplateLiteral of
Expand All @@ -159,7 +175,12 @@ module Syntax =
span: Span
mutable inferred_type: option<Type.Type> }

type ObjTypeAnnElem = int
type ObjTypeAnnElem =
| Callable of Function
| Constructor of Function
| Method of name: string * is_mut: bool * type_: Function
| Getter of name: string * return_type: TypeAnn * throws: TypeAnn
| Setter of name: string * param: FuncParam<TypeAnn> * throws: TypeAnn

type KeywordTypeAnn =
| Boolean
Expand All @@ -175,14 +196,14 @@ module Syntax =
type TypeParam =
{ span: Span
name: string
bound: option<Type.Type>
default_: option<Type.Type> }
bound: option<TypeAnn>
default_: option<TypeAnn> }

type FunctionType =
{ type_params: option<list<TypeParam>>
params_: list<TypeAnn>
params_: list<FuncParam<TypeAnn>>
return_type: TypeAnn
throws: option<Type.Type> }
throws: option<TypeAnn> }

type ConditionType =
{ check: TypeAnn
Expand Down Expand Up @@ -245,7 +266,7 @@ module Type =
| Tuple of elems: list<Pattern>
| Wildcard
| Literal of Syntax.Literal
| Is of target: Pattern * id: string
| Is of target: Syntax.BindingIdent * id: string
| Rest of target: Pattern

override this.ToString() =
Expand Down Expand Up @@ -421,7 +442,7 @@ module Type =
| TypeRef of
name: string *
type_args: option<list<Type>> *
scheme: option<Scheme>
scheme: option<Scheme> // used so that we can reference a type ref's scheme without importing it
| Literal of Syntax.Literal
| Primitive of Primitive
| Tuple of list<Type>
Expand Down
185 changes: 185 additions & 0 deletions src/Escalier.Data/Visitor.fs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
namespace Escalier.Data

open Escalier.Data.Syntax
open Escalier.Data.Type

module Visitor =
let rec walk_type (v: Type -> unit) (t: Type) : unit =
v t

match t.kind with
| Array elem -> walk_type v elem
| TypeVar tv ->
maybe_walk_type v tv.instance |> ignore
maybe_walk_type v tv.bound |> ignore
| TypeRef(_, typeArgs, scheme) ->
Option.map (List.iter (walk_type v)) typeArgs |> ignore

Option.map
(fun (scheme: Scheme) ->
walk_type v scheme.type_
walk_type_params v scheme.type_params)
scheme
|> ignore
| Literal _ -> () // leaf node
| Primitive _ -> () // leaf node
| Tuple types -> List.iter (walk_type v) types
| Union types -> List.iter (walk_type v) types
| Intersection types -> List.iter (walk_type v) types
| Keyword _ -> () // leaf node
| Function f ->
List.iter (walk_type v) (List.map (fun p -> p.type_) f.param_list)
| Object objTypeElems ->
for elem in objTypeElems do
match elem with
| Callable(callable) ->
List.iter
(walk_type v)
(List.map (fun p -> p.type_) callable.param_list)

walk_type v callable.return_type
| _ -> ()
| Rest t -> walk_type v t
| KeyOf t -> walk_type v t
| Index(target, index) ->
walk_type v target
walk_type v index
| Condition(check, extends, trueType, falseType) ->
walk_type v check
walk_type v extends
walk_type v trueType
walk_type v falseType
| Infer _ -> () // leaf node
| Wildcard -> () // leaf node
| Binary(left, _op, right) ->
walk_type v left
walk_type v right

and maybe_walk_type (v: Type -> unit) (ot: option<Type>) : unit =
Option.map (walk_type v) ot |> ignore

and walk_func (v: Type -> unit) (f: Function) : unit =
List.iter (walk_type v) (List.map (fun p -> p.type_) f.param_list)
walk_type v f.return_type
walk_type v f.throws
Option.map (walk_type_params v) f.type_params |> ignore

and walk_type_params (v: Type -> unit) (tp: list<TypeParam>) : unit =
List.iter
(fun (tp: TypeParam) ->
maybe_walk_type v tp.bound
maybe_walk_type v tp.default_)
tp

type SyntaxVisitor() =
abstract member VisitExpr: Expr -> unit
abstract member VisitStmt: Stmt -> unit
abstract member VisitTypeAnn: TypeAnn -> unit
abstract member VisitPattern: Syntax.Pattern -> unit
abstract member VisitBlock: Block -> unit
abstract member VisitScript: Script -> unit

default this.VisitExpr(e: Expr) =
match e.kind with
| ExprKind.Assign(left, _op, right) ->
this.VisitExpr left
this.VisitExpr right
| ExprKind.Identifer _ -> ()
| ExprKind.Literal _ -> ()
| ExprKind.Object elems ->
List.iter
(fun (elem: ObjElem) ->
// TODO: add support for computed keys
match elem with
| ObjElem.Property(_span, _key, value) -> this.VisitExpr value
| ObjElem.Spread(_span, value) -> this.VisitExpr value)
elems
| ExprKind.Tuple elems -> List.iter this.VisitExpr elems
| ExprKind.Binary(left, _, right) ->
this.VisitExpr left
this.VisitExpr right
| ExprKind.Unary(_op, value) -> this.VisitExpr value
| ExprKind.Function f ->
match f.body with
| BlockOrExpr.Expr(e) -> this.VisitExpr e
| BlockOrExpr.Block(b) -> this.VisitBlock b
| ExprKind.Call(callee, typeArgs, args, _optChain, _throws) ->
this.VisitExpr callee

match typeArgs with
| Some(typeArgs) -> List.iter this.VisitTypeAnn typeArgs
| None -> ()

List.iter this.VisitExpr args
| ExprKind.Index(target, index, _optChain) ->
this.VisitExpr target
this.VisitExpr index
| ExprKind.Member(target, _name, _optChain) -> this.VisitExpr target
| ExprKind.IfElse(cond, thenBranch, elseBranch) ->
this.VisitExpr cond

match thenBranch with
| BlockOrExpr.Expr(e) -> this.VisitExpr e
| BlockOrExpr.Block(b) -> this.VisitBlock b

match elseBranch with
| Some(BlockOrExpr.Expr(e)) -> this.VisitExpr e
| Some(BlockOrExpr.Block(b)) -> this.VisitBlock b
| _ -> ()
| ExprKind.Match(target, cases) ->
this.VisitExpr target

List.iter
(fun (case: MatchCase) ->
this.VisitPattern case.pattern
this.MaybeWalkExpr case.guard
this.VisitExpr case.body)
cases
| ExprKind.Try(body, catch, ``finally``) -> failwith "todo"
| ExprKind.Do body -> failwith "todo"
| ExprKind.Await value -> this.VisitExpr value
| ExprKind.Throw value -> this.VisitExpr value
| ExprKind.TemplateLiteral templateLiteral -> failwith "todo"
| ExprKind.TaggedTemplateLiteral(tag, template, throws) -> failwith "todo"

member this.MaybeWalkExpr(e: option<Expr>) =
match e with
| Some(e) -> this.VisitExpr e
| None -> ()

default this.VisitStmt(s: Stmt) =
match s.kind with
| Decl decl ->
match decl.kind with
| DeclKind.VarDecl(pattern, exprOption, typeAnnOption, isDeclare) ->
this.VisitPattern pattern
this.MaybeWalkExpr exprOption
this.MaybeWalkTypeAnn typeAnnOption
| DeclKind.TypeDecl(name, typeAnn, typeParamsOption) ->
this.VisitTypeAnn typeAnn

match typeParamsOption with
| Some(typeParams) ->
List.iter
(fun (typeParam: Syntax.TypeParam) ->
this.MaybeWalkTypeAnn typeParam.bound
this.MaybeWalkTypeAnn typeParam.default_)
typeParams
| None -> ()
| Expr expr -> this.VisitExpr expr
| For(left, right, body) ->
this.VisitPattern left
this.VisitExpr right
List.iter this.VisitStmt body.stmts
| Return exprOption -> this.MaybeWalkExpr exprOption

default this.VisitTypeAnn(ta: TypeAnn) = ()

member this.MaybeWalkTypeAnn(ota: option<TypeAnn>) =
match ota with
| Some(ta) -> this.VisitTypeAnn ta
| None -> ()

default this.VisitPattern(p: Syntax.Pattern) = ()
default this.VisitBlock(b: Block) = List.iter this.VisitStmt b.stmts
default this.VisitScript(s: Script) = ()
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
input: fn (x, y) { x }
output: Success: { kind =
Function
(["x"; "y"],
Block { span = { start = (Ln: 1, Col: 11)
stop = (Ln: 1, Col: 16) }
stmts = [{ span = { start = (Ln: 1, Col: 13)
stop = (Ln: 1, Col: 15) }
kind = Expr { kind = Identifer "x"
span = { start = (Ln: 1, Col: 13)
stop = (Ln: 1, Col: 15) }
inferred_type = None } }] })
{ param_list =
[{ pattern = { kind = Identifier { span = { start = (Ln: 1, Col: 5)
stop = (Ln: 1, Col: 6) }
name = "x"
isMut = false }
span = { start = (Ln: 1, Col: 5)
stop = (Ln: 1, Col: 6) }
inferred_type = None }
typeAnn = None
optional = false };
{ pattern = { kind = Identifier { span = { start = (Ln: 1, Col: 8)
stop = (Ln: 1, Col: 9) }
name = "y"
isMut = false }
span = { start = (Ln: 1, Col: 8)
stop = (Ln: 1, Col: 9) }
inferred_type = None }
typeAnn = None
optional = false }]
return_type = None
type_params = None
throws = None
body =
Block { span = { start = (Ln: 1, Col: 11)
stop = (Ln: 1, Col: 16) }
stmts = [{ span = { start = (Ln: 1, Col: 13)
stop = (Ln: 1, Col: 15) }
kind = Expr { kind = Identifer "x"
span = { start = (Ln: 1, Col: 13)
stop = (Ln: 1, Col: 15) }
inferred_type = None } }] } }
span = { start = (Ln: 1, Col: 1)
stop = (Ln: 1, Col: 16) }
inferred_type = None }
Loading

0 comments on commit 22b4a64

Please sign in to comment.