feat: type check

This commit is contained in:
2025-11-04 02:37:01 +08:00
parent 26357753f8
commit 7c9ba92769
22 changed files with 2418 additions and 88 deletions

View File

@@ -15,7 +15,7 @@ pub(all) enum Type {
} derive(Show)
///|
enum Literal {
pub(all) enum Literal {
Unit
Bool(Bool)
Int(Int)
@@ -23,20 +23,20 @@ enum Literal {
} derive(Show)
///|
enum AddSubOp {
pub(all) enum AddSubOp {
Add
Sub
} derive(Show)
///|
enum MulDivRemOp {
pub(all) enum MulDivRemOp {
Mul
Div
Rem
} derive(Show)
///|
enum Pattern {
pub(all) enum Pattern {
Wildcard
Identifier(String)
Literal(Literal)
@@ -45,7 +45,7 @@ enum Pattern {
} derive(Show)
///|
enum Expr {
pub(all) enum Expr {
Or(Expr, Expr)
And(Expr, Expr)
Compare(CompareOperator, Expr, Expr)
@@ -69,7 +69,7 @@ enum Expr {
} derive(Show)
///|
struct Function {
pub(all) struct Function {
id : String
user_defined_type : Type?
params : Array[(String, Type?)]
@@ -78,15 +78,15 @@ struct Function {
} derive(Show)
///|
enum Binding {
pub(all) enum Binding {
Identifier(String)
Wildcard
Tuple(Array[Binding])
} derive(Show)
///|
enum Stmt {
pub(all) enum Stmt {
Let(Binding, Type?, Expr)
LetTuple(Array[Binding], Type?, Expr)
LetMut(String, Type?, Expr)
Assign(Expr, Expr)
While(Expr, Array[Stmt])
@@ -96,28 +96,28 @@ enum Stmt {
} derive(Show)
///|
struct TopLet {
pub(all) struct TopLet {
id : String
type_ : Type?
expr : Expr
} derive(Show)
///|
struct StructDef {
pub(all) struct StructDef {
id : String
user_defined_type : Type?
fields : Array[(String, Type)]
} derive(Show)
///|
struct EnumDef {
pub(all) struct EnumDef {
id : String
user_defined_type : Type?
variants : Array[(String, Array[Type])]
} derive(Show)
///|
struct Program {
pub(all) struct Program {
top_lets : Map[String, TopLet]
top_functions : Map[String, Function]
struct_defs : Map[String, StructDef]
@@ -125,7 +125,7 @@ struct Program {
} derive(Show)
///|
fn parse_type(
pub fn parse_type(
tokens : ArrayView[Token],
) -> (Type, ArrayView[Token]) raise ParseError {
match tokens {
@@ -136,7 +136,7 @@ fn parse_type(
[Array, LBracket, .. rest] => {
let (elem_type, rest) = parse_type(rest)
guard rest is [RBracket, .. rest] else {
raise ParseError("Expected ']' after array type")
raise ParseError("Expected ']' after array type, got \{rest[0]}")
}
(Array(elem_type), rest)
}
@@ -156,7 +156,7 @@ fn parse_type(
[RParen, ..] => continue r
_ =>
raise ParseError(
"Expected ',' or ')' or ')' '->' in tuple/function type",
"Expected ',' or ')' or ')' '->' in tuple/function type, got \{r[0]}",
)
}
continue r
@@ -166,21 +166,26 @@ fn parse_type(
[UpperIdentifier(name), LBracket, .. rest] => {
let (type_arg, rest) = parse_type(rest)
guard rest is [RBracket, .. rest] else {
raise ParseError("Expected ']' after generic type argument")
raise ParseError(
"Expected ']' after generic type argument, got \{rest[0]}",
)
}
(Generic(name, type_arg), rest)
}
[UpperIdentifier(name), .. rest] => (UserDefined(name), rest)
_ => raise ParseError("Unexpected token while parsing type")
tokens =>
raise ParseError("Unexpected token while parsing type: \{tokens[0]}")
}
}
///|
fn parse_struct_decl(
pub fn parse_struct_decl(
tokens : ArrayView[Token],
) -> (StructDef, ArrayView[Token]) raise ParseError {
guard tokens is [Struct, UpperIdentifier(id), .. rest] else {
raise ParseError("Expected upper case struct name after 'struct'")
raise ParseError(
"Expected upper case struct name after 'struct', got \{tokens[1]}",
)
}
let (user_defined_type, rest) = if rest
is [LBracket, UpperIdentifier(type_), RBracket, .. r] {
@@ -189,7 +194,7 @@ fn parse_struct_decl(
(None, rest)
}
guard rest is [LCurlyBracket, .. rest] else {
raise ParseError("Expected '{' after struct name")
raise ParseError("Expected '{' after struct name, got \{rest[0]}")
}
let fields = []
loop rest {
@@ -200,19 +205,25 @@ fn parse_struct_decl(
match r {
[Semicolon, .. r] => continue r
[RCurlyBracket, ..] => continue r
_ => raise ParseError("Expected ';' or '}' after struct field")
_ =>
raise ParseError(
"Expected ';' or '}' after struct field, got \{r[0]}",
)
}
}
_ => raise ParseError("Unexpected token in struct field list")
rest =>
raise ParseError("Unexpected token in struct field list, got \{rest[0]}")
}
}
///|
fn parse_enum_decl(
pub fn parse_enum_decl(
tokens : ArrayView[Token],
) -> (EnumDef, ArrayView[Token]) raise ParseError {
guard tokens is [Enum, UpperIdentifier(id), .. rest] else {
raise ParseError("Expected upper case enum name after 'enum'")
raise ParseError(
"Expected upper case enum name after 'enum', got \{tokens[1]}",
)
}
let (user_defined_type, rest) = if rest
is [LBracket, UpperIdentifier(type_), RBracket, .. r] {
@@ -221,7 +232,7 @@ fn parse_enum_decl(
(None, rest)
}
guard rest is [LCurlyBracket, .. rest] else {
raise ParseError("Expected '{' after enum name")
raise ParseError("Expected '{' after enum name, got \{rest[0]}")
}
let variants = []
loop rest {
@@ -239,7 +250,7 @@ fn parse_enum_decl(
[RParen, ..] => continue r
_ =>
raise ParseError(
"Expected ',' or ')' in enum variant type list",
"Expected ',' or ')' in enum variant type list, got \{r[0]}",
)
}
}
@@ -251,15 +262,19 @@ fn parse_enum_decl(
match r {
[Semicolon, .. r] => continue r
[RCurlyBracket, ..] => continue r
_ => raise ParseError("Expected ';' or '}' after enum variant")
_ =>
raise ParseError(
"Expected ';' or '}' after enum variant, got \{r[0]}",
)
}
}
_ => raise ParseError("Unexpected token in enum variant list")
rest =>
raise ParseError("Unexpected token in enum variant list, got \{rest[0]}")
}
}
///|
fn parse_optional_type_annotation(
pub fn parse_optional_type_annotation(
tokens : ArrayView[Token],
) -> (Type?, ArrayView[Token]) raise ParseError {
if tokens is [Colon, .. r] {
@@ -271,18 +286,18 @@ fn parse_optional_type_annotation(
}
///|
fn parse_let_stmt_type_expr(
pub fn parse_let_stmt_type_expr(
tokens : ArrayView[Token],
) -> (Type?, Expr, ArrayView[Token]) raise ParseError {
let (type_, rest) = parse_optional_type_annotation(tokens)
guard rest is [Assign, .. rest] else {
raise ParseError(
"Expected '=' after identifier or type annotation in let declaration",
"Expected '=' after identifier or type annotation in let declaration, got \{rest[0]}",
)
}
let (expr, rest) = parse_expr(rest)
guard rest is [Semicolon, .. rest] else {
raise ParseError("Expected ';' after let statement")
raise ParseError("Expected ';' after let statement, got \{rest[0]}")
}
(type_, expr, rest)
}
@@ -291,7 +306,7 @@ fn parse_let_stmt_type_expr(
/// Returns `(Stmt, is_end_expr, rest_tokens)`.
/// Semicolon is consumed for statements.
/// End expr is wrapped in `Stmt::Expr`.
fn parse_stmt_or_expr_end(
pub fn parse_stmt_or_expr_end(
tokens : ArrayView[Token],
) -> (Stmt, Bool, ArrayView[Token]) raise ParseError {
match tokens {
@@ -306,19 +321,25 @@ fn parse_stmt_or_expr_end(
match binding {
LowerIdentifier(name) | UpperIdentifier(name) => Identifier(name)
Wildcard => Wildcard
_ => raise ParseError("Unreachable")
_ => panic()
},
)
match r {
[Comma, .. r] => continue r
[RParen, ..] => continue r
_ => raise ParseError("Expected ',' or ')' after let binding")
_ =>
raise ParseError(
"Expected ',' or ')' after let binding, got \{r[0]}",
)
}
}
_ => raise ParseError("Unexpected token in let tuple binding list")
rest =>
raise ParseError(
"Unexpected token in let tuple binding list, got \{rest[0]}",
)
}
let (type_, expr, rest) = parse_let_stmt_type_expr(rest)
(LetTuple(bindings, type_, expr), false, rest)
(Let(Tuple(bindings), type_, expr), false, rest)
}
// let_mut_stmt
@@ -332,7 +353,7 @@ fn parse_stmt_or_expr_end(
let binding : Binding = match b {
LowerIdentifier(name) | UpperIdentifier(name) => Identifier(name)
Wildcard => Wildcard
_ => raise ParseError("Unreachable")
_ => panic()
}
let (type_, expr, rest) = parse_let_stmt_type_expr(rest)
(Let(binding, type_, expr), false, rest)
@@ -350,10 +371,15 @@ fn parse_stmt_or_expr_end(
[Comma, .. r] => continue r
[RParen, ..] => continue r
_ =>
raise ParseError("Expected ',' or ')' after function parameter")
raise ParseError(
"Expected ',' or ')' after function parameter, got \{r[0]}",
)
}
}
_ => raise ParseError("Unexpected token in function parameter list")
rest =>
raise ParseError(
"Unexpected token in function parameter list, got \{rest[0]}",
)
}
let (return_type, rest) = if rest is [Arrow, .. r] {
let (t, r) = parse_type(r)
@@ -374,7 +400,7 @@ fn parse_stmt_or_expr_end(
[Return, .. rest] => {
let (expr, rest) = parse_expr(rest)
guard rest is [Semicolon, .. rest] else {
raise ParseError("Expected ';' after return statement")
raise ParseError("Expected ';' after return statement, got \{rest[0]}")
}
(Return(expr), false, rest)
}
@@ -383,13 +409,19 @@ fn parse_stmt_or_expr_end(
[While, .. rest] => {
let (cond_expr, rest) = parse_expr(rest)
let stmts = []
guard rest is [LCurlyBracket, .. rest] else {
raise ParseError("Expected '{' after while condition, got \{rest[0]}")
}
loop rest {
[RCurlyBracket, .. r] => (While(cond_expr, stmts), false, r)
[RCurlyBracket, .. r] => {
stmts.push(Expr(Literal(Unit)))
(While(cond_expr, stmts), false, r)
}
r => {
let (stmt, is_end_expr, rest) = parse_stmt_or_expr_end(r)
if is_end_expr {
raise ParseError(
"Unexpected end expression in while statement body",
"Unexpected end expression \{stmt} in while statement body",
)
}
stmts.push(stmt)
@@ -410,28 +442,35 @@ fn parse_stmt_or_expr_end(
_ => false
}
guard valid_left_value else {
raise ParseError("Invalid left-hand side in assignment")
raise ParseError("Invalid left-hand side in assignment: \{expr}")
}
let (rhs_expr, r) = parse_expr(r)
guard r is [Semicolon, .. r] else {
raise ParseError("Expected ';' after assignment statement")
raise ParseError(
"Expected ';' after assignment statement, got \{r[0]}",
)
}
(Assign(expr, rhs_expr), false, r)
}
[Semicolon, .. r] => (Expr(expr), false, r)
[RCurlyBracket, ..] => (Expr(expr), true, rest)
_ => raise ParseError("Expected ';' or '}' after expression statement")
rest =>
raise ParseError(
"Expected ';' or '}' after expression statement, got \{rest[0]}",
)
}
}
}
}
///|
fn parse_block_expr(
pub fn parse_block_expr(
tokens : ArrayView[Token],
) -> (Expr, ArrayView[Token]) raise ParseError {
guard tokens is [LCurlyBracket, .. rest] else {
raise ParseError("Expected '{' at start of block expression")
raise ParseError(
"Expected '{' at start of block expression, got \{tokens[0]}",
)
}
let stmts = []
let rest = loop rest {
@@ -440,7 +479,9 @@ fn parse_block_expr(
stmts.push(stmt)
if is_end_expr {
guard r is [RCurlyBracket, .. r] else {
raise ParseError("Expected '}' at end of block expression")
raise ParseError(
"Expected '}' at end of block expression, got \{r[0]}",
)
}
break r
} else if r is [RCurlyBracket, .. r] {
@@ -454,11 +495,13 @@ fn parse_block_expr(
}
///|
fn parse_if_expr(
pub fn parse_if_expr(
tokens : ArrayView[Token],
) -> (Expr, ArrayView[Token]) raise ParseError {
guard tokens is [If, .. rest] else {
raise ParseError("Expected 'if' at start of if expression")
raise ParseError(
"Expected 'if' at start of if expression, got \{tokens[0]}",
)
}
let (cond, rest) = parse_expr(rest)
let (then_branch, rest) = parse_block_expr(rest)
@@ -477,7 +520,7 @@ fn parse_if_expr(
}
///|
fn parse_value_level_expr(
pub fn parse_value_level_expr(
tokens : ArrayView[Token],
) -> (Expr, ArrayView[Token]) raise ParseError {
match tokens {
@@ -486,11 +529,15 @@ fn parse_value_level_expr(
[Array, DoubleColon, LowerIdentifier("make"), LParen, .. rest] => {
let (size_expr, rest) = parse_expr(rest)
guard rest is [Comma, .. rest] else {
raise ParseError("Expected ',' after size expression in array make")
raise ParseError(
"Expected ',' after size expression in array make, got \{rest[0]}",
)
}
let (init_expr, rest) = parse_expr(rest)
guard rest is [RParen, .. rest] else {
raise ParseError("Expected ')' after init expression in array make")
raise ParseError(
"Expected ')' after init expression in array make, got \{rest[0]}",
)
}
(ArrayMake(size_expr, init_expr), rest)
}
@@ -506,10 +553,16 @@ fn parse_value_level_expr(
match r {
[Comma, .. r] => continue r
[RCurlyBracket, ..] => continue r
_ => raise ParseError("Expected ',' or '}' after struct field")
_ =>
raise ParseError(
"Expected ',' or '}' after struct field, got \{r[0]}",
)
}
}
_ => raise ParseError("Unexpected token in struct construction")
rest =>
raise ParseError(
"Unexpected token in struct construction, got \{rest[0]}",
)
}
}
@@ -547,8 +600,10 @@ fn parse_value_level_expr(
} else {
(Tuple(exprs), r)
}
_ =>
raise ParseError("Expected ',' or ')' in tuple or grouped expression")
r =>
raise ParseError(
"Expected ',' or ')' in tuple or grouped expression, got \{r[0]}",
)
}
}
@@ -562,10 +617,10 @@ fn parse_value_level_expr(
elements.push(element)
match r {
[Comma, .. r] => continue r
[RBracket, ..] => (Array(elements), r)
[RBracket, ..] => continue r
_ =>
raise ParseError(
"Expected ',' or ']' in array literal expression",
"Expected ',' or ']' in array literal expression, got \{r[0]}",
)
}
}
@@ -578,12 +633,13 @@ fn parse_value_level_expr(
// identifier_expr
[LowerIdentifier(name) | UpperIdentifier(name), .. rest] =>
(Identifier(name), rest)
_ => raise ParseError("Unsupported value level expression")
tokens =>
raise ParseError("Unsupported value level expression: \{tokens[:3]}")
}
}
///|
fn parse_get_or_apply_level_expr(
pub fn parse_get_or_apply_level_expr(
tokens : ArrayView[Token],
) -> (Expr, ArrayView[Token]) raise ParseError {
let (value, rest) = parse_value_level_expr(tokens)
@@ -592,7 +648,7 @@ fn parse_get_or_apply_level_expr(
[LBracket, .. r] => {
let (index, r) = parse_expr(r)
guard r is [RBracket, .. r] else {
raise ParseError("Expected ']' after index expression")
raise ParseError("Expected ']' after index expression, got \{r[0]}")
}
result = IndexAccess(result, index)
continue r
@@ -606,8 +662,11 @@ fn parse_get_or_apply_level_expr(
args.push(arg)
match r {
[Comma, .. r] => continue r
[RParen, ..] => break r
_ => raise ParseError("Expected ',' or ')' in function call args")
[RParen, ..] => continue r
_ =>
raise ParseError(
"Expected ',' or ')' in function call args, got \{r[0]}",
)
}
}
}
@@ -623,7 +682,7 @@ fn parse_get_or_apply_level_expr(
}
///|
fn parse_if_level_expr(
pub fn parse_if_level_expr(
tokens : ArrayView[Token],
) -> (Expr, ArrayView[Token]) raise ParseError {
match tokens {
@@ -634,7 +693,7 @@ fn parse_if_level_expr(
}
///|
fn parse_mul_div_level_expr(
pub fn parse_mul_div_level_expr(
tokens : ArrayView[Token],
) -> (Expr, ArrayView[Token]) raise ParseError {
let (first, rest) = parse_if_level_expr(tokens)
@@ -660,7 +719,7 @@ fn parse_mul_div_level_expr(
}
///|
fn parse_add_sub_level_expr(
pub fn parse_add_sub_level_expr(
tokens : ArrayView[Token],
) -> (Expr, ArrayView[Token]) raise ParseError {
let (first, rest) = parse_mul_div_level_expr(tokens)
@@ -681,7 +740,7 @@ fn parse_add_sub_level_expr(
}
///|
fn parse_compare_level_expr(
pub fn parse_compare_level_expr(
tokens : ArrayView[Token],
) -> (Expr, ArrayView[Token]) raise ParseError {
let (first, rest) = parse_add_sub_level_expr(tokens)
@@ -695,7 +754,7 @@ fn parse_compare_level_expr(
}
///|
fn parse_and_level_expr(
pub fn parse_and_level_expr(
tokens : ArrayView[Token],
) -> (Expr, ArrayView[Token]) raise ParseError {
let (first, rest) = parse_compare_level_expr(tokens)
@@ -711,7 +770,8 @@ fn parse_and_level_expr(
}
///|
fn parse_or_level_expr(
#alias(parse_expr)
pub fn parse_or_level_expr(
tokens : ArrayView[Token],
) -> (Expr, ArrayView[Token]) raise ParseError {
let (first, rest) = parse_and_level_expr(tokens)
@@ -726,13 +786,6 @@ fn parse_or_level_expr(
}
}
///|
fn parse_expr(
tokens : ArrayView[Token],
) -> (Expr, ArrayView[Token]) raise ParseError {
parse_or_level_expr(tokens)
}
///|
pub fn parse_program(tokens : Array[Token]) -> Program raise ParseError {
let top_lets = Map::new()
@@ -745,13 +798,15 @@ pub fn parse_program(tokens : Array[Token]) -> Program raise ParseError {
let (type_, rest) = parse_optional_type_annotation(rest)
guard rest is [Assign, .. rest] else {
raise ParseError(
"Expected '=' after identifier or type annotation in let declaration",
"Expected '=' after identifier or type annotation in let declaration, got \{rest[0]}",
)
}
let (expr, rest) = parse_expr(rest)
top_lets[id] = { id, type_, expr }
guard rest is [Semicolon, .. rest] else {
raise ParseError("Expected ';' after top let declaration")
raise ParseError(
"Expected ';' after top let declaration, got \{rest[0]}",
)
}
continue rest
}
@@ -782,9 +837,9 @@ pub fn parse_program(tokens : Array[Token]) -> Program raise ParseError {
}
[UpperIdentifier(id) | LowerIdentifier(id), LParen, .. r] =>
(None, id, r)
_ =>
rest =>
raise ParseError(
"Expected function name (with optional type) after 'fn'",
"Expected function name (with optional type) after 'fn', got \{rest[0]}",
)
}
let params = []
@@ -797,13 +852,20 @@ pub fn parse_program(tokens : Array[Token]) -> Program raise ParseError {
[Comma, .. r] => continue r
[RParen, ..] => continue r
_ =>
raise ParseError("Expected ',' or ')' after function parameter")
raise ParseError(
"Expected ',' or ')' after function parameter, got \{r[0]}",
)
}
}
_ => raise ParseError("Unexpected token in function parameter list")
rest =>
raise ParseError(
"Unexpected token in function parameter list, got \{rest[0]}",
)
}
guard rest is [Arrow, .. rest] else {
raise ParseError("Expected '->' after function parameter list")
raise ParseError(
"Expected '->' after function parameter list, got \{rest[0]}",
)
}
let (return_type, rest) = parse_type(rest)
guard parse_block_expr(rest) is (Block(body), rest)

View File

@@ -209,7 +209,7 @@ test "parse_get_or_apply_level_expr" {
EOF,
]),
content=(
#|(FunctionCall(Identifier("f"), [Literal(Int(1))]), [RParen, EOF])
#|(FunctionCall(Identifier("f"), [Literal(Int(1))]), [EOF])
),
)
}

View File

@@ -257,7 +257,7 @@ pub fn tokenize(input : String) -> Array[Token] raise TokenizeError {
continue rest
}
[] => tokens.push(EOF)
[c, ..] => raise TokenizeError("Unexpected character: \{c}")
str => raise TokenizeError("Unexpected token: \{try? str[:20]} ...")
}
tokens
}

177
src/typecheck/expr.mbt Normal file
View File

@@ -0,0 +1,177 @@
///|
pub(all) struct Expr {
kind : ExprKind
ty : TypeKind
} derive(Show)
///|
pub(all) enum ExprKind {
ApplyExpr(ApplyExpr)
BlockExpr(BlockExpr)
NotExpr(Expr)
NegExpr(Expr)
Compare(CompareOperator, Expr, Expr)
AddSub(AddSubOp, Expr, Expr)
MulDivRem(MulDivRemOp, Expr, Expr)
And(Expr, Expr)
Or(Expr, Expr)
IfExpr(IfExpr)
} derive(Show)
///|
pub(all) enum CompareOperator {
Equal
NotEqual
GreaterEqual
LessEqual
Greater
Less
} derive(Show)
///|
pub(all) enum AddSubOp {
Add
Sub
} derive(Show)
///|
pub(all) enum MulDivRemOp {
Mul
Div
Rem
} derive(Show)
///|
pub fn Context::check_expr(
self : Self,
expr : @parser.Expr,
) -> Expr raise TypeCheckError {
match expr {
Literal(_)
| Identifier(_)
| Tuple(_)
| Array(_)
| ArrayMake(_)
| StructConstruct(_)
| IndexAccess(_)
| FieldAccess(_)
| FunctionCall(_) => {
let apply_expr = self.check_apply_expr(expr)
{ kind: ApplyExpr(apply_expr), ty: apply_expr.ty }
}
Not(inner_expr) => {
let inner_expr = self.check_expr(inner_expr)
guard inner_expr.ty is Bool else {
raise TypeCheckError("Operand of '!' must be Bool.")
}
{ kind: NotExpr(inner_expr), ty: Bool }
}
Neg(inner_expr) => {
let inner_expr = self.check_expr(inner_expr)
match inner_expr.ty {
Int => { kind: NegExpr(inner_expr), ty: Int }
Double => { kind: NegExpr(inner_expr), ty: Double }
_ => raise TypeCheckError("Operand of unary '-' must be Int or Double.")
}
}
Compare(op, left, right) => {
let left = self.check_expr(left)
let right = self.check_expr(right)
if !self.is_type_compatible(left.ty, right.ty) {
raise TypeCheckError(
"Operands of comparison must be of compatible types.",
)
}
guard self.deref_type_var(left.ty)
is (Int | Double | Bool | Unit | TypeVar(_)) else {
// XXX: Constraint TypeVar to these types
raise TypeCheckError(
"Operands of comparison must be Int, Double, Bool, or Unit.",
)
}
{ kind: Compare(compare_op_map(op), left, right), ty: Bool }
}
AddSub(op, left, right) => {
let left = self.check_expr(left)
let right = self.check_expr(right)
if !self.is_type_compatible(left.ty, right.ty) {
raise TypeCheckError("Operands of '+'/'-' must be of compatible types.")
}
match self.deref_type_var(left.ty) {
Int => { kind: AddSub(add_sub_op_map(op), left, right), ty: Int }
Double => { kind: AddSub(add_sub_op_map(op), left, right), ty: Double }
_ => raise TypeCheckError("Operands of '+'/'-' must be Int or Double.")
}
}
MulDivRem(op, left, right) => {
let left = self.check_expr(left)
let right = self.check_expr(right)
if !self.is_type_compatible(left.ty, right.ty) {
raise TypeCheckError(
"Operands of '*','/','%' must be of compatible types.",
)
}
match self.deref_type_var(left.ty) {
Int => { kind: MulDivRem(mul_div_rem_op_map(op), left, right), ty: Int }
Double =>
{ kind: MulDivRem(mul_div_rem_op_map(op), left, right), ty: Double }
_ =>
raise TypeCheckError("Operands of '*','/','%' must be Int or Double.")
}
}
And(left, right) => {
let left = self.check_expr(left)
let right = self.check_expr(right)
guard left.ty is Bool && right.ty is Bool else {
raise TypeCheckError("Operands of '&&' must be Bool.")
}
{ kind: And(left, right), ty: Bool }
}
Or(left, right) => {
let left = self.check_expr(left)
let right = self.check_expr(right)
guard left.ty is Bool && right.ty is Bool else {
raise TypeCheckError("Operands of '||' must be Bool.")
}
{ kind: Or(left, right), ty: Bool }
}
If(_) => {
let if_expr = self.check_if_expr(expr)
{ kind: IfExpr(if_expr), ty: if_expr.ty }
}
Block(_) => {
let block_expr = self.check_block_expr(expr)
{ kind: BlockExpr(block_expr), ty: block_expr.ty }
}
EnumConstruct(_) | Match(_) => ...
}
}
///|
fn compare_op_map(op : @parser.CompareOperator) -> CompareOperator {
match op {
Equal => Equal
NotEqual => NotEqual
GreaterEqual => GreaterEqual
LessEqual => LessEqual
Greater => Greater
Less => Less
}
}
///|
fn add_sub_op_map(op : @parser.AddSubOp) -> AddSubOp {
match op {
Add => Add
Sub => Sub
}
}
///|
fn mul_div_rem_op_map(op : @parser.MulDivRemOp) -> MulDivRemOp {
match op {
Mul => Mul
Div => Div
Rem => Rem
}
}

View File

@@ -0,0 +1,95 @@
///|
pub(all) struct ApplyExpr {
kind : ApplyExprKind
ty : TypeKind
} derive(Show)
///|
pub(all) enum ApplyExprKind {
AtomExpr(AtomExpr)
ArrayAccess(ApplyExpr, Expr)
FieldAccess(ApplyExpr, String)
Call(ApplyExpr, Array[Expr])
} derive(Show)
///|
pub fn Context::check_apply_expr(
self : Self,
apply_expr : @parser.Expr,
) -> ApplyExpr raise TypeCheckError {
match apply_expr {
Literal(_)
| Identifier(_)
| Tuple(_)
| Array(_)
| ArrayMake(_)
| StructConstruct(_) => {
let atom_expr = self.check_atom_expr(apply_expr)
{ kind: AtomExpr(atom_expr), ty: atom_expr.ty }
}
IndexAccess(base_expr, index_expr) => {
let base_typed = self.check_apply_expr(base_expr)
guard base_typed.ty is Array(elem_type) else {
raise TypeCheckError("Attempting to index a non-array type.")
}
let index_typed = self.check_expr(index_expr)
guard index_typed.ty is Int else {
raise TypeCheckError("Array index must be of type Int.")
}
{ kind: ArrayAccess(base_typed, index_typed), ty: elem_type }
}
FieldAccess(base_expr, field_name) => {
let base_typed = self.check_apply_expr(base_expr)
match base_typed.ty {
Array(t) =>
{
kind: FieldAccess(base_typed, field_name),
ty: match field_name {
"length" => Function([], Int)
"push" => Function([t], Unit)
_ => raise TypeCheckError("Array has no method '\{field_name}'.")
},
}
Struct(struct_name) => {
let struct_def = self.struct_defs.get(struct_name).unwrap()
let field_type = struct_def
.get_field_type(field_name)
.or_error(
TypeCheckError(
"Struct '\{struct_name}' has no field '\{field_name}'.",
),
)
{ kind: FieldAccess(base_typed, field_name), ty: field_type.kind }
}
t => raise TypeCheckError("Attempting to access field of type '\{t}'.")
}
}
FunctionCall(callee_expr, arg_exprs) => {
let callee_typed = self.check_apply_expr(callee_expr)
guard callee_typed.ty is Function(param_types, return_type) else {
raise TypeCheckError("Attempting to call a non-function type.")
}
let args_count = arg_exprs.length()
guard param_types.length() == args_count else {
raise TypeCheckError(
"Function called with incorrect number of arguments.",
)
}
let args_typed = []
for i in 0..<args_count {
let arg_typed = self.check_expr(arg_exprs[i])
guard self.is_type_compatible(param_types[i], arg_typed.ty) else {
raise TypeCheckError(
"Type of argument \{i + 1} does not match function parameter type.",
)
}
args_typed.push(arg_typed)
}
{
kind: Call(callee_typed, args_typed),
ty: self.deref_type_var(return_type),
}
}
e => raise TypeCheckError("Unsupported apply expression \{e}")
}
}

106
src/typecheck/expr_atom.mbt Normal file
View File

@@ -0,0 +1,106 @@
///|
pub(all) struct AtomExpr {
kind : AtomExprKind
ty : TypeKind
} derive(Show)
///|
pub(all) enum AtomExprKind {
Int(Int) // 1, 42, etc
Double(Double) // 1.0, 3.14, etc
Bool(Bool) // true | false
Unit // ()
Ident(String) // var
Tuple(Array[Expr]) // (expr, expr, ...)
Array(Array[Expr]) // [expr, expr, ...]
ArrayMake(Expr, Expr) // Array::make(size, init)
StructConstruct(StructConstructExpr) // StructName::{ field: expr, ... }
} derive(Show)
///|
pub(all) struct StructConstructExpr {
name : String
fields : Array[(String, Expr)]
} derive(Show)
///|
pub fn Context::check_atom_expr(
self : Self,
atom_expr : @parser.Expr,
) -> AtomExpr raise TypeCheckError {
match atom_expr {
Literal(Int(v)) => { kind: Int(v), ty: Int }
Literal(Double(v)) => { kind: Double(v), ty: Double }
Literal(Bool(v)) => { kind: Bool(v), ty: Bool }
Literal(Unit) => { kind: Unit, ty: Unit }
Identifier(name) => {
let var_type = self.lookup_type(name)
match var_type {
Some(ty) => { kind: Ident(name), ty: ty.kind }
None => raise TypeCheckError("Undefined variable: \{name}")
}
}
Tuple(elems) => {
let exprs = []
let types = []
for expr in elems {
let typed_expr = self.check_expr(expr)
exprs.push(typed_expr)
types.push(typed_expr.ty)
}
{ kind: Tuple(exprs), ty: Tuple(types) }
}
Array([]) => { kind: Array([]), ty: Array(self.new_type_var()) }
Array([first_elem, .. other_elems]) => {
let first_elem = self.check_expr(first_elem)
let elem_type = first_elem.ty
let exprs = [first_elem]
for expr in other_elems {
let typed_expr = self.check_expr(expr)
if !self.is_type_compatible(elem_type, typed_expr.ty) {
raise TypeCheckError(
"Incompatible types in array expression: \{typed_expr.ty} is not \{elem_type}",
)
}
exprs.push(typed_expr)
}
{ kind: Array(exprs), ty: Array(elem_type) }
}
ArrayMake(size_expr, init_expr) => {
let size_typed = self.check_expr(size_expr)
guard size_typed.ty is Int else {
raise TypeCheckError("Array size must be of type Int.")
}
let init_typed = self.check_expr(init_expr)
{ kind: ArrayMake(size_typed, init_typed), ty: Array(init_typed.ty) }
}
StructConstruct(name, fields) => {
let def = self.struct_defs
.get(name)
.or_error(TypeCheckError("Undefined struct: \{name}"))
let field_exprs = []
for field in fields {
let (field_name, field_expr) = field
let expected_type = def
.get_field_type(field_name)
.or_error(
TypeCheckError(
"Struct '\{name}' has no field named '\{field_name}'.",
),
)
let typed_expr = self.check_expr(field_expr)
if !self.is_type_compatible(expected_type.kind, typed_expr.ty) {
raise TypeCheckError(
"Type mismatch for field '\{field_name}' in struct '\{name}'.",
)
}
field_exprs.push((field_name, typed_expr))
}
guard field_exprs.length() == def.fields.length() else {
raise TypeCheckError("Struct '\{name}' construction missing fields.")
}
{ kind: StructConstruct({ name, fields: field_exprs }), ty: Struct(name) }
}
e => raise TypeCheckError("Unsupported atom expression \{e}")
}
}

View File

@@ -0,0 +1,23 @@
///|
pub(all) struct BlockExpr {
stmts : Array[Stmt]
ty : TypeKind
} derive(Show)
///|
pub fn Context::check_block_expr(
self : Context,
block_expr : @parser.Expr,
) -> BlockExpr raise TypeCheckError {
guard block_expr is Block(stmts) else {
raise TypeCheckError("Expected block expression")
}
self.enter_scope()
let checked_stmts = stmts.map(stmt => self.check_stmt(stmt))
self.exit_scope()
let stmts_count = checked_stmts.length()
guard stmts_count > 0 && checked_stmts[stmts_count - 1].kind is ExprStmt(expr) else {
return { stmts: checked_stmts, ty: Unit }
}
{ stmts: checked_stmts, ty: expr.ty }
}

45
src/typecheck/expr_if.mbt Normal file
View File

@@ -0,0 +1,45 @@
///|
pub(all) struct IfExpr {
cond : Expr
then_block : BlockExpr
else_block : Expr?
ty : TypeKind
} derive(Show)
///|
pub fn Context::check_if_expr(
self : Context,
if_expr : @parser.Expr,
) -> IfExpr raise TypeCheckError {
guard if_expr is If(cond, then_block, else_block) else {
raise TypeCheckError("Expected if expression")
}
let cond = self.check_expr(cond)
guard self.is_type_compatible(cond.ty, Bool) else {
raise TypeCheckError("if condition must be of type Bool, got \{cond.ty}")
}
guard self.check_expr(then_block) is { kind: BlockExpr(then_block), .. } else {
raise TypeCheckError("then block must be a block expression")
}
let else_block = match else_block {
Some(else_expr) => {
let checked_else = self.check_expr(else_expr)
guard self.is_type_compatible(checked_else.ty, then_block.ty) else {
raise TypeCheckError(
"else block type \{checked_else.ty} does not match then block type \{then_block.ty}",
)
}
Some(checked_else)
}
None => {
guard self.is_type_compatible(then_block.ty, Unit) else {
raise TypeCheckError(
"if expression without else block must have then block of type Unit, got \{then_block.ty}",
)
}
None
}
}
let ty = self.deref_type_var(then_block.ty)
{ cond, then_block, else_block, ty }
}

View File

@@ -0,0 +1,5 @@
{
"import": [
"Lil-Ran/lilunar/parser"
]
}

153
src/typecheck/program.mbt Normal file
View File

@@ -0,0 +1,153 @@
///|
pub(all) suberror TypeCheckError String derive(Show)
///|
pub(all) struct Env {
local_ : Map[String, Type]
parent : Env?
}
///|
pub fn Env::new(parent? : Env? = None) -> Env {
{ local_: Map::new(), parent }
}
///|
pub fn Env::get(self : Env, name : String) -> Type? {
match self.local_.get(name) {
Some(t) => Some(t)
None =>
match self.parent {
Some(p) => p.get(name)
None => None
}
}
}
///|
pub fn Env::set(self : Env, name : String, t : Type) -> Unit {
self.local_.set(name, t)
}
///|
pub(all) struct Context {
mut type_env : Env
type_vars : Map[Int, TypeKind]
struct_defs : Map[String, StructDef]
func_types : Map[String, TypeKind]
mut current_func_ret_ty : TypeKind?
}
///|
pub fn Context::new() -> Context {
{
type_env: Env::new(),
type_vars: Map::new(),
struct_defs: Map::new(),
func_types: Map::new(),
current_func_ret_ty: None,
}
}
///|
pub fn Context::new_type_var(self : Context) -> TypeKind {
let type_var_id = self.type_vars.length()
self.type_vars.set(type_var_id, TypeVar(type_var_id))
TypeVar(type_var_id)
}
///|
pub fn Context::lookup_type(self : Context, name : String) -> Type? {
let looked = self.type_env.get(name)
if looked is None {
return None
}
let { kind, mutable } = looked.unwrap()
let deref_kind = self.deref_type_var(kind)
Some({ kind: deref_kind, mutable })
}
///|
pub fn Context::enter_scope(self : Self) -> Unit {
let sub_env = Env::new(parent=Some(self.type_env))
self.type_env = sub_env
}
///|
pub fn Context::exit_scope(self : Context) -> Unit {
self.type_env = match self.type_env.parent {
Some(p) => p
None => self.type_env
}
}
///|
pub fn Context::set_current_func_ret_ty(self : Context, ty : TypeKind) -> Unit {
self.current_func_ret_ty = Some(ty)
}
///|
pub(all) struct Program {
top_lets : Map[String, TopLet]
top_functions : Map[String, TopFunction]
struct_defs : Map[String, StructDef]
} derive(Show)
///|
pub fn Context::check_program(
self : Context,
program : @parser.Program,
) -> Program raise TypeCheckError {
self.collect_struct_names(program)
self.collect_function_types(program)
let struct_defs : Map[String, StructDef] = Map::new()
for name, struct_def in program.struct_defs {
let checked_struct_def = self.check_struct_def(struct_def)
struct_defs.set(name, checked_struct_def)
self.struct_defs.set(name, checked_struct_def)
}
let top_functions : Map[String, TopFunction] = Map::new()
for name, func in program.top_functions {
let checked_func = self.check_top_function(func)
top_functions.set(name, checked_func)
self.func_types.set(
name,
Function(checked_func.param_list.map(p => p.ty), checked_func.ret_ty),
)
}
let top_lets : Map[String, TopLet] = Map::new()
for name, let_def in program.top_lets {
let checked_let = self.check_top_let(let_def)
top_lets.set(name, checked_let)
self.type_env.set(name, checked_let.ty)
}
{ top_lets, top_functions, struct_defs }
}
///|
pub fn Context::collect_struct_names(
self : Context,
program : @parser.Program,
) -> Unit raise TypeCheckError {
for name, _ in program.struct_defs {
if self.struct_defs.contains(name) {
raise TypeCheckError("Duplicate struct definition: \{name}")
}
self.struct_defs.set(name, { name, fields: [] })
}
}
///|
pub fn Context::collect_function_types(
self : Context,
program : @parser.Program,
) -> Unit {
for name, _ in program.top_functions {
let func_type = self.new_type_var()
self.func_types.set(name, func_type)
}
for name, _ in program.top_lets {
let let_type = self.new_type_var()
self.type_env.set(name, { kind: let_type, mutable: false })
}
}

43
src/typecheck/stmt.mbt Normal file
View File

@@ -0,0 +1,43 @@
///|
pub(all) struct Stmt {
kind : StmtKind
} derive(Show)
///|
pub(all) enum StmtKind {
LetStmt(LetStmt)
LetMutStmt(LetMutStmt)
AssignStmt(AssignStmt)
WhileStmt(WhileStmt)
ExprStmt(Expr)
ReturnStmt(Expr)
LocalFunction(LocalFunction)
} derive(Show)
///|
pub fn Context::check_stmt(
self : Context,
stmt : @parser.Stmt,
) -> Stmt raise TypeCheckError {
match stmt {
Return(expr) => {
guard self.current_func_ret_ty is Some(expected) else {
raise TypeCheckError("return statement outside of function")
}
let actual = self.check_expr(expr)
if !self.is_type_compatible(actual.ty, expected) {
raise TypeCheckError(
"return type mismatch: expected \{expected}, got \{actual.ty}",
)
}
{ kind: ReturnStmt(actual) }
}
Expr(expr) => { kind: ExprStmt(self.check_expr(expr)) }
While(_) => { kind: WhileStmt(self.check_while_stmt(stmt)) }
Assign(_) => { kind: AssignStmt(self.check_assign_stmt(stmt)) }
LetMut(_) => { kind: LetMutStmt(self.check_let_mut_stmt(stmt)) }
Let(_) => { kind: LetStmt(self.check_let_stmt(stmt)) }
LocalFunction(func) =>
{ kind: LocalFunction(self.check_local_function(func)) }
}
}

View File

@@ -0,0 +1,84 @@
///|
pub(all) struct LeftValue {
kind : LeftValueKind
ty : Type
} derive(Show)
///|
pub(all) enum LeftValueKind {
Ident(String)
ArrayAccess(LeftValue, Expr)
FieldAccess(LeftValue, String)
} derive(Show)
///|
pub(all) struct AssignStmt {
left_value : LeftValue
expr : Expr
} derive(Show)
///|
pub fn Context::check_left_value(
self : Context,
lv : @parser.Expr,
) -> LeftValue raise TypeCheckError {
match lv {
Identifier(name) => {
guard self.lookup_type(name) is Some(var_info) else {
raise TypeCheckError("Undefined variable '\{name}'.")
}
{ kind: Ident(name), ty: var_info }
}
IndexAccess(base, index_expr) => {
let base_typed = self.check_left_value(base)
let index_typed = self.check_expr(index_expr)
guard index_typed.ty is Int else {
raise TypeCheckError("Array index must be of type Int.")
}
guard base_typed.ty.kind is Array(elem_type) else {
raise TypeCheckError("Base of array access is not an array.")
}
{
kind: ArrayAccess(base_typed, index_typed),
ty: { kind: elem_type, mutable: true }, // MiniMoonBit Array elements are always mutable
}
}
FieldAccess(base, field_name) => {
let base_typed = self.check_left_value(base)
guard base_typed.ty.kind is Struct(name) else {
raise TypeCheckError("Base of field access is not a struct.")
}
let field_type = self.struct_defs
.get(name)
.or_error(TypeCheckError("Undefined struct: \{name}"))
.get_field_type(field_name)
.or_error(TypeCheckError("Struct has no field named '\{field_name}'."))
{
kind: FieldAccess(base_typed, field_name),
ty: { kind: field_type.kind, mutable: base_typed.ty.mutable },
}
}
lv => raise TypeCheckError("Invalid left value expression: \{lv}")
}
}
///|
pub fn Context::check_assign_stmt(
self : Context,
assign_stmt : @parser.Stmt,
) -> AssignStmt raise TypeCheckError {
guard assign_stmt is Assign(left_value, expr) else {
raise TypeCheckError("Expected AssignStmt.")
}
let left_value_typed = self.check_left_value(left_value)
guard left_value_typed.ty.mutable else {
raise TypeCheckError("Cannot assign to immutable left value \{left_value}.")
}
let expr_typed = self.check_expr(expr)
guard self.is_type_compatible(expr_typed.ty, left_value_typed.ty.kind) else {
raise TypeCheckError(
"Cannot assign value of type \{expr_typed.ty} to variable of type \{left_value_typed.ty}.",
)
}
{ left_value: left_value_typed, expr: expr_typed }
}

View File

@@ -0,0 +1,77 @@
///|
pub(all) struct LetStmt {
pattern : Pattern
ty : TypeKind
expr : Expr
} derive(Show)
///|
pub(all) struct Pattern {
kind : PatternKind
} derive(Show, Eq)
///|
pub(all) enum PatternKind {
Wildcard
Ident(String)
Tuple(Array[Pattern])
} derive(Show, Eq)
///|
pub fn Context::check_let_stmt(
self : Context,
let_stmt : @parser.Stmt,
) -> LetStmt raise TypeCheckError {
guard let_stmt is Let(pattern, ty, expr) else {
raise TypeCheckError("Expected LetStmt.")
}
let expr_typed = self.check_expr(expr)
let mut result_ty = expr_typed.ty
if ty is Some(ty) {
result_ty = self.check_parser_type(ty).kind
guard self.is_type_compatible(result_ty, expr_typed.ty) else {
raise TypeCheckError(
"Type annotation does not match expression type for pattern '\{pattern}'.",
)
}
}
let pattern = parser_pattern_map(pattern)
result_ty = self.deref_type_var(result_ty)
self.set_pattern_types(pattern, result_ty)
{ pattern, ty: result_ty, expr: expr_typed }
}
///|
pub fn Context::set_pattern_types(
self : Context,
pattern : Pattern,
ty : TypeKind,
) -> Unit raise TypeCheckError {
match pattern.kind {
Ident(name) => self.type_env.set(name, { kind: ty, mutable: false })
Wildcard => ()
Tuple(patterns) => {
guard ty is Tuple(elem_types) else {
raise TypeCheckError("Type mismatch for tuple pattern.")
}
guard patterns.length() == elem_types.length() else {
raise TypeCheckError("Tuple pattern length does not match type.")
}
for i in 0..<patterns.length() {
self.set_pattern_types(patterns[i], elem_types[i])
}
}
}
}
///|
pub fn parser_pattern_map(pattern : @parser.Binding) -> Pattern {
match pattern {
Identifier(name) => { kind: Ident(name) }
Wildcard => { kind: Wildcard }
Tuple(bindings) => {
let patterns = bindings.map(binding => parser_pattern_map(binding))
{ kind: Tuple(patterns) }
}
}
}

View File

@@ -0,0 +1,29 @@
///|
pub(all) struct LetMutStmt {
name : String
ty : Type
expr : Expr
} derive(Show)
///|
pub fn Context::check_let_mut_stmt(
self : Context,
stmt : @parser.Stmt,
) -> LetMutStmt raise TypeCheckError {
guard stmt is LetMut(name, ty, expr) else {
raise TypeCheckError("Expected LetMutStmt.")
}
let expr_typed = self.check_expr(expr)
let mut result_ty = expr_typed.ty
if ty is Some(ty) {
result_ty = self.check_parser_type(ty).kind
guard self.is_type_compatible(result_ty, expr_typed.ty) else {
raise TypeCheckError(
"Type annotation does not match expression type for variable '\{name}'.",
)
}
}
result_ty = self.deref_type_var(result_ty)
self.type_env.set(name, { kind: result_ty, mutable: true })
{ name, ty: { kind: result_ty, mutable: true }, expr: expr_typed }
}

View File

@@ -0,0 +1,52 @@
///|
pub(all) struct LocalFunction {
fname : String
param_list : Array[(String, Type)]
ret_ty : Type
body : BlockExpr
} derive(Show)
///|
pub fn Context::check_local_function(
self : Context,
func : @parser.Function,
) -> LocalFunction raise TypeCheckError {
let param_list = func.params.map(param => {
let (param_name, param_type) = param
match param_type {
Some(ty) =>
(param_name, { kind: self.check_parser_type(ty).kind, mutable: false })
None => (param_name, { kind: self.new_type_var(), mutable: false })
}
})
let ret_type = match func.return_type {
Some(ty) => self.check_parser_type(ty).kind
None => self.new_type_var()
}
let func_type = Function(param_list.map(p => p.1.kind), ret_type)
self.type_env.set(func.id, { kind: func_type, mutable: false })
//
self.enter_scope()
self.current_func_ret_ty = Some(ret_type)
//
for param in param_list {
let (param_name, param_type) = param
self.type_env.set(param_name, param_type)
}
//
let checked_body = self.check_block_expr(Block(func.body))
if !self.is_type_compatible(checked_body.ty, ret_type) {
raise TypeCheckError(
"Function '\{func.id}' return type mismatch: expected \{ret_type}, got \{checked_body.ty}",
)
}
//
self.exit_scope()
//
{
fname: func.id,
param_list,
ret_ty: { kind: ret_type, mutable: false },
body: checked_body,
}
}

View File

@@ -0,0 +1,23 @@
///|
pub(all) struct WhileStmt {
cond : Expr
body : BlockExpr
} derive(Show)
///|
pub fn Context::check_while_stmt(
self : Context,
while_stmt : @parser.Stmt,
) -> WhileStmt raise TypeCheckError {
guard while_stmt is While(cond_expr, body_expr) else {
raise TypeCheckError("Expected while statement")
}
let cond = self.check_expr(cond_expr)
guard self.is_type_compatible(cond.ty, Bool) else {
raise TypeCheckError("while condition must be of type Bool, got \{cond.ty}")
}
guard self.check_expr(Block(body_expr)) is { kind: BlockExpr(body_block), .. } else {
raise TypeCheckError("while body must be a block expression")
}
{ cond, body: body_block }
}

View File

@@ -0,0 +1,36 @@
///|
pub(all) struct StructDef {
name : String
fields : Array[StructField]
} derive(Show)
///|
pub(all) struct StructField {
name : String
ty : Type
} derive(Show)
///|
pub fn StructDef::get_field_type(self : Self, field_name : String) -> Type? {
for field in self.fields {
if field.name == field_name {
return Some(field.ty)
}
}
return None
}
///|
pub fn Context::check_struct_def(
self : Self,
struct_def : @parser.StructDef,
) -> StructDef raise TypeCheckError {
{
name: struct_def.id,
fields: struct_def.fields.map(field => {
let (field_name, field_type) = field
let field_type = self.check_parser_type(field_type)
{ name: field_name, ty: field_type }
}),
}
}

View File

@@ -0,0 +1,63 @@
///|
pub(all) struct Param {
name : String
ty : TypeKind
} derive(Show)
///|
pub(all) struct TopFunction {
fname : String
param_list : Array[Param]
ret_ty : TypeKind
body : BlockExpr
} derive(Show)
///|
pub fn Context::check_top_function(
self : Context,
func : @parser.Function,
) -> TopFunction raise TypeCheckError {
// XXX: 目前的泛型不是真正的泛型,只是只能归一成单个类型的类型变量
if func.user_defined_type is Some(UserDefined(udt)) {
self.type_env.set("$Generic$\{udt}", {
kind: self.new_type_var(),
mutable: false,
})
}
//
let param_list : Array[Param] = func.params.map(param => {
let (param_name, param_type) = param
match param_type {
Some(ty) => { name: param_name, ty: self.check_parser_type(ty).kind }
None => { name: param_name, ty: self.new_type_var() }
}
})
let ret_ty = match func.return_type {
Some(ty) => self.check_parser_type(ty).kind
None => self.new_type_var()
}
let func_type = Function(param_list.map(p => p.ty), ret_ty)
self.type_env.set(func.id, { kind: func_type, mutable: false })
//
self.enter_scope()
self.current_func_ret_ty = Some(ret_ty)
//
for param in param_list {
let { name, ty } = param
self.type_env.set(name, { kind: ty, mutable: false })
}
//
let checked_body = self.check_block_expr(Block(func.body))
if !self.is_type_compatible(checked_body.ty, ret_ty) {
raise TypeCheckError(
"Function '\{func.id}' return type mismatch: expected \{ret_ty}, got \{checked_body.ty}",
)
}
//
self.exit_scope()
if func.user_defined_type is Some(UserDefined(udt)) {
self.type_env.local_.remove("$Generic$\{udt}")
}
//
{ fname: func.id, param_list, ret_ty, body: checked_body }
}

27
src/typecheck/top_let.mbt Normal file
View File

@@ -0,0 +1,27 @@
///|
pub(all) struct TopLet {
name : String
ty : Type
expr : Expr
} derive(Show)
///|
pub fn Context::check_top_let(
self : Context,
top_let : @parser.TopLet,
) -> TopLet raise TypeCheckError {
let { id: name, type_: ty, expr } = top_let
let expr_typed = self.check_expr(expr)
let mut result_ty = expr_typed.ty
if ty is Some(ty) {
result_ty = self.check_parser_type(ty).kind
guard self.is_type_compatible(result_ty, expr_typed.ty) else {
raise TypeCheckError(
"Type annotation does not match expression type for variable '\{name}'.",
)
}
}
result_ty = self.deref_type_var(result_ty)
self.type_env.set(name, { kind: result_ty, mutable: false })
{ name, ty: { kind: result_ty, mutable: false }, expr: expr_typed }
}

View File

@@ -0,0 +1,826 @@
///|
test "TypeCheck Normal Type Test" {
let ctx = Context::new()
let code =
#|Unit Int Bool Double
#|(Int, Bool) Array[Int]
#|Array[(Int, Double)]
#|(Int, Bool) -> Double
let tokens = @parser.tokenize(code)
let (t, tok_view) = @parser.parse_type(tokens)
let t = ctx.check_parser_type(t)
assert_true(t.kind is Unit)
let (t, tok_view) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(t.kind is Int)
let (t, tok_view) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(t.kind is Bool)
let (t, tok_view) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(t.kind is Double)
let (t, tok_view) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(t.kind is Tuple([t1, t2]) && t1 is Int && t2 is Bool)
let (t, tok_view) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(t.kind is Array(t1) && t1 is Int)
let (t, tok_view) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(
t.kind is Array(t1) && t1 is Tuple([t2, t3]) && t2 is Int && t3 is Double,
)
let (t, _) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(
t.kind is Function([t1, t2], t3) && t1 is Int && t2 is Bool && t3 is Double,
)
}
///|
test "TypeCheck Defined Type Test" {
let ctx = Context::new()
ctx.struct_defs.set("Point", { name: "Point", fields: [] })
ctx.struct_defs.set("Circle", { name: "Circle", fields: [] })
ctx.struct_defs.set("Rectangle", { name: "Rectangle", fields: [] })
// Note: No Triangle
let code =
#|Point Circle Rectangle Triangle
#|Array[Point] (Point, Circle)
#|(Rectangle, Point) -> Circle
#|Array[Triangle]
let tokens = @parser.tokenize(code)
let (t, tok_view) = @parser.parse_type(tokens)
let t = ctx.check_parser_type(t)
assert_true(t.kind is Struct("Point"))
let (t, tok_view) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(t.kind is Struct("Circle"))
let (t, tok_view) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(t.kind is Struct("Rectangle"))
let (t, tok_view) = @parser.parse_type(tok_view)
let t = try? ctx.check_parser_type(t)
assert_true(t is Err(_)) // Triangle is not defined
let (t, tok_view) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(t.kind is Array(t1) && t1 is Struct("Point"))
let (t, tok_view) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(
t.kind is Tuple([t1, t2]) && t1 is Struct("Point") && t2 is Struct("Circle"),
)
let (t, tok_view) = @parser.parse_type(tok_view)
let t = ctx.check_parser_type(t)
assert_true(
t.kind is Function([t1, t2], t3) &&
t1 is Struct("Rectangle") &&
t2 is Struct("Point") &&
t3 is Struct("Circle"),
)
let (t, _) = @parser.parse_type(tok_view)
let t = try? ctx.check_parser_type(t)
assert_true(t is Err(_)) // Triangle is not defined
}
///|
test "Struct Definition Typecheck" {
let code =
#|struct Point { x: Int; y: Int; }
#|struct Queue { data: Array[Int]; front: Int; back: Int; }
let tokens = @parser.tokenize(code)
let ctx = Context::new()
// Parse
let (struct_def, tok_view) = @parser.parse_struct_decl(tokens[:])
// Typecheck for `struct Point { x: Int; y: Int }`
let struct_def = ctx.check_struct_def(struct_def)
assert_true(struct_def.name is "Point")
assert_true(struct_def.fields.length() is 2)
assert_true(
struct_def.fields is [f1, f2] &&
f1 is { name: "x", ty: { kind: Int, mutable: false } } &&
f2 is { name: "y", ty: { kind: Int, mutable: false } },
)
// Typecheck for `struct Queue { data: Array[Int]; front: Int; back: Int }`
let (struct_def, _) = @parser.parse_struct_decl(tok_view)
let struct_def = ctx.check_struct_def(struct_def)
assert_true(struct_def.name is "Queue")
assert_true(struct_def.fields.length() is 3)
assert_true(
struct_def.fields is [f1, f2, f3] &&
f1 is { name: "data", ty: { kind: Array(Int), mutable: false } } &&
f2 is { name: "front", ty: { kind: Int, mutable: false } } &&
f3 is { name: "back", ty: { kind: Int, mutable: false } },
)
}
///|
test "Type Var Test - 1" {
let ctx = Context::new()
ctx.type_vars.set(0, TypeVar(0))
ctx.type_vars.set(1, TypeVar(1))
ctx.type_vars.set(2, TypeVar(2))
ctx.type_vars.set(3, TypeVar(3))
// TVar(0) = Double
let t1 = ctx.is_type_compatible(TypeVar(0), Double)
assert_true(t1 is true)
assert_true(ctx.type_vars.get(0).unwrap() is Double)
// TVar(1) = TVar(0)
let t2 = ctx.is_type_compatible(TypeVar(1), TypeVar(0))
assert_true(t2 is true)
assert_true(ctx.type_vars.get(1).unwrap() is Double)
// TVar(2) = TVar(3)
let t3 = ctx.is_type_compatible(TypeVar(2), TypeVar(3))
assert_true(t3 is true)
// TVAr(3) = Int, therefore TVar(2) = Int
let t4 = ctx.is_type_compatible(TypeVar(3), Int)
assert_true(t4 is true)
assert_true(ctx.type_vars.get(2).unwrap() is Int)
assert_true(ctx.type_vars.get(3).unwrap() is Int)
let t5 = ctx.is_type_compatible(TypeVar(2), Double)
assert_true(t5 is false)
}
///|
///
/// 0 == 1 == 2 == 3
/// 2 == (Double, Int)
///
/// 0, 1, 2, 3 should all resolve to (Double, Int)
test "Type Var Test - 2" {
let ctx = Context::new()
ctx.type_vars.set(0, TypeVar(0))
ctx.type_vars.set(1, TypeVar(1))
ctx.type_vars.set(2, TypeVar(2))
ctx.type_vars.set(3, TypeVar(3))
//
let _ = ctx.is_type_compatible(TypeVar(0), TypeVar(1))
let _ = ctx.is_type_compatible(TypeVar(1), TypeVar(2))
let _ = ctx.is_type_compatible(TypeVar(2), TypeVar(3))
let _ = ctx.is_type_compatible(TypeVar(3), TypeVar(0))
let _ = ctx.is_type_compatible(TypeVar(2), Tuple([Double, Int]))
let t0 = ctx.type_vars.get(0).unwrap()
let t1 = ctx.type_vars.get(1).unwrap()
let t2 = ctx.type_vars.get(2).unwrap()
let t3 = ctx.type_vars.get(3).unwrap()
assert_true(t0 is Tuple([Double, Int]))
assert_true(t1 is Tuple([Double, Int]))
assert_true(t2 is Tuple([Double, Int]))
assert_true(t3 is Tuple([Double, Int]))
}
///|
test "Simple Atom Expression Type Check" {
let code = "42 3.14 true x y"
let ctx = Context::new()
ctx.type_env.set("x", { kind: Double, mutable: false })
ctx.type_env.set("y", { kind: Bool, mutable: false })
// Parse
let tokens = @parser.tokenize(code)
// Type of 42
let (a, tok_view) = @parser.parse_expr(tokens[:])
let a = ctx.check_atom_expr(a)
assert_true(a.ty is Int)
// Type of 3.14
let (a, tok_view) = @parser.parse_expr(tok_view)
let a = ctx.check_atom_expr(a)
assert_true(a.ty is Double)
// Type of true
let (a, tok_view) = @parser.parse_expr(tok_view)
let a = ctx.check_atom_expr(a)
assert_true(a.ty is Bool)
// Type of x
let (a, tok_view) = @parser.parse_expr(tok_view)
let a = ctx.check_atom_expr(a)
assert_true(a.ty is Double)
// Type of y
let (a, _) = @parser.parse_expr(tok_view)
let a = ctx.check_atom_expr(a)
assert_true(a.ty is Bool)
}
///|
test "Simple Apply Expression Type Check" {
let code =
#|42 3.14 true x y
let ctx = Context::new()
ctx.type_env.set("x", { kind: Double, mutable: false })
ctx.type_env.set("y", { kind: Bool, mutable: false })
// Parse
let tokens = @parser.tokenize(code)
// Type of 42
let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tokens[:])
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Int)
// Type of 3.14
let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view)
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Double)
// Type of true
let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view)
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Bool)
// Type of x
let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view)
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Double)
// Type of y
let (a, _) = @parser.parse_get_or_apply_level_expr(tok_view)
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Bool)
}
///|
test "Simple Expression Type Check" {
let code =
#|42 ; 3.14 ; true ;
#|!true ; !false ; !x ;
#|-42 ; -3.14 ; -y ;
#|1 + 3; 4 - 5 ; 6 * 7; 8.0 / 2.0; 9 % 4;
#|1 > 2; 42.0 >= y; x == true ;
let ctx = Context::new()
ctx.type_env.set("x", { kind: Bool, mutable: false })
ctx.type_env.set("y", { kind: Double, mutable: false })
// Parse
let tokens = @parser.tokenize(code)
// Type of 42
let (a, tok_view) = @parser.parse_expr(tokens[:])
let a = ctx.check_expr(a)
assert_true(a.ty is Int)
// Type of 3.14
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Double)
// Type of true
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Bool)
// Type of !true
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Bool)
// Type of !false
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Bool)
// Type of !x
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Bool)
// Type of -42
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Int)
// Type of -3.14
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Double)
// Type of -y
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Double)
// Type of 1 + 3
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Int)
// Type of 4 - 5
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Int)
// Type of 6 * 7
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Int)
// Type of 8.0 / 2.0
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Double)
// Type of 9 % 4
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Int)
//// Type of 1 > 2
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Bool)
//// Type of 42 == y
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Bool)
//// Type of x == true
let (a, _) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Bool)
}
///|
test "Atom Expression Type Check" {
let code =
#|[1, 2+3, 3+4+5, 6*7+8] (1+3, !y) Array::make(5, 0) []
let ctx = Context::new()
ctx.type_env.set("x", { kind: Double, mutable: false })
ctx.type_env.set("y", { kind: Bool, mutable: false })
// Parse
let tokens = @parser.tokenize(code)
// Type of [1, 2+3, 3+4+5, 6*7+8]
let (a, tok_view) = @parser.parse_value_level_expr(tokens[:])
let a = ctx.check_atom_expr(a)
assert_true(a.ty is Array(Int))
// Type of (1+3, !y)
let (b, tok_view) = @parser.parse_value_level_expr(tok_view)
let b = ctx.check_atom_expr(b)
assert_true(b.ty is Tuple([Int, Bool]))
// Type of Array::make(5, 0)
let (c, tok_view) = @parser.parse_value_level_expr(tok_view)
let c = ctx.check_atom_expr(c)
assert_true(c.ty is Array(Int))
// Type of []
let (d, _) = @parser.parse_value_level_expr(tok_view)
let d = ctx.check_atom_expr(d)
assert_true(d.ty is Array(TypeVar(_)))
}
///|
test "Apply Expression Type Check" {
let code =
#|arr[3] ; fact(5) ;
#|mat[3][4] ; sum(arr); max(x, y);
#|arr.push(10); mat.length();
let ctx = Context::new()
ctx.type_env.set("arr", { kind: Array(Int), mutable: false })
ctx.type_env.set("mat", { kind: Array(Array(Int)), mutable: false })
ctx.type_env.set("fact", { kind: Function([Int], Int), mutable: false })
ctx.type_env.set("sum", { kind: Function([Array(Int)], Int), mutable: false })
// Note: max is local function without type annotations
// like: fn max(a, b) { ... }
ctx.type_env.set("max", {
kind: Function([TypeVar(0), TypeVar(0)], TypeVar(0)),
mutable: false,
})
ctx.type_env.set("x", { kind: Double, mutable: true })
ctx.type_env.set("y", { kind: Double, mutable: true })
ctx.type_vars.set(0, TypeVar(0))
// Parse
let tokens = @parser.tokenize(code)
// Type of arr[3]
let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tokens[:])
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Int)
// Type of fact(5)
let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view[1:])
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Int)
// Type of mat[3][4]
let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view[1:])
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Int)
// Type of sum(arr)
let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view[1:])
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Int)
// Type of max(x, y)
let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view[1:])
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Double)
// Type of arr.push(10)
let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view[1:])
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Unit)
// Type of mat.length()
let (a, _) = @parser.parse_get_or_apply_level_expr(tok_view[1:])
let a = ctx.check_apply_expr(a)
assert_true(a.ty is Int)
}
///|
test "Let Mut Stmt Type Check" {
let code =
#|let mut x : Double = 42.0;
#|let mut y = 10;
#|let mut z = [];
#|let _ = z.push(3.14);
#|let _ = z.push(false); // Should fail
let ctx = Context::new()
// Parse
let tokens = @parser.tokenize(code)
// Type check for `let mut x : Double = 42.0;`
let (let_mut_stmt1, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:])
let checked_let_mut_stmt1 = ctx.check_let_mut_stmt(let_mut_stmt1)
assert_true(checked_let_mut_stmt1.ty.kind is Double)
// Type check for `let mut y = 10;`
let (let_mut_stmt2, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let checked_let_mut_stmt2 = ctx.check_let_mut_stmt(let_mut_stmt2)
assert_true(checked_let_mut_stmt2.ty.kind is Int)
// Type check for `let mut z = [];`
let (let_mut_stmt3, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let checked_let_mut_stmt3 = ctx.check_let_mut_stmt(let_mut_stmt3)
assert_true(checked_let_mut_stmt3.ty.kind is Array(TypeVar(_)))
// Type check for `let _ = z.push(3.14);`
// Parse Let Stmt
let (let_stmt4, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let checked_let_stmt4 = ctx.check_let_stmt(let_stmt4)
assert_true(checked_let_stmt4.ty is Unit)
assert_true(ctx.lookup_type("z") is Some({ kind: Array(Double), .. }))
// Type check for `let _ = z.push("Hello");` - Should fail
let (let_stmt5, _, _) = @parser.parse_stmt_or_expr_end(tok_view)
let check_result = try? ctx.check_let_stmt(let_stmt5)
assert_true(check_result is Err(_))
}
///|
test "Top Let Stmt Type Check" {
let code =
#|let x : Int = 42;
#|let y = true;
let ctx = Context::new()
// Parse
let tokens = @parser.tokenize(code)
let { top_lets, .. } = @parser.parse_program(tokens)
// Type check for `let x : Int = 42;`
let checked_top_let1 = ctx.check_top_let(top_lets["x"])
assert_true(checked_top_let1.ty.kind is Int)
let ty1 = ctx.lookup_type("x")
assert_true(ty1 is Some(t) && t.kind is Int && t.mutable == false)
// Type check for `let y = true;`
let checked_top_let2 = ctx.check_top_let(top_lets["y"])
assert_true(checked_top_let2.ty.kind is Bool)
let ty2 = ctx.lookup_type("y")
assert_true(ty2 is Some(t) && t.kind is Bool && t.mutable == false)
}
///|
test "Let Stmt Type Check" {
let code =
#|let x : Double = 42.0;
#|let (x, y) = (1, 2);
#|let z = [];
#|let mat = [[1, 2], z];
#|let _ = z.push(true);
let ctx = Context::new()
// Parse
let tokens = @parser.tokenize(code)
// Type check for `let x : Int = 42;`
let (let_stmt1, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:])
let checked_let_stmt1 = ctx.check_let_stmt(let_stmt1)
assert_true(checked_let_stmt1.ty is Double)
// Type check for `let (x, y) = (1, 2);`
let (let_stmt2, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let checked_let_stmt2 = ctx.check_let_stmt(let_stmt2)
assert_true(checked_let_stmt2.ty is Tuple([Int, Int]))
assert_true(ctx.lookup_type("x") is Some({ kind: Int, .. }))
assert_true(ctx.lookup_type("y") is Some({ kind: Int, .. }))
// Type check for `let z = [];`
let (let_stmt3, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let checked_let_stmt3 = ctx.check_let_stmt(let_stmt3)
assert_true(checked_let_stmt3.ty is Array(TypeVar(_)))
// Type check for `let mat = [[1, 2], z];`
let (let_stmt4, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let checked_let_stmt4 = ctx.check_let_stmt(let_stmt4)
assert_true(checked_let_stmt4.ty is Array(Array(Int)))
assert_true(ctx.lookup_type("mat") is Some({ kind: Array(Array(Int)), .. }))
let zty = ctx.lookup_type("z")
assert_true(
ctx.lookup_type("z") is Some({ kind: Array(Int), .. }),
msg="Type of z is \{zty}",
)
// Type check for `let _ = z.push(true);`
// Should fail
let (let_stmt5, _, _) = @parser.parse_stmt_or_expr_end(tok_view)
let check_result = try? ctx.check_let_stmt(let_stmt5)
assert_true(check_result is Err(_))
}
///|
test "Assign Stmt Type Check" {
let code =
#|let mut x = 10;
#|x = 5;
#|let arr = [1, 2, 3];
#|arr[0] = 10;
#|x = x + true; // Should fail
#|let mut a = xxx; // xxx's type is typevar
#|a = a + 33; // `a` and `xxx` should be inferred as Int
let ctx = Context::new()
// Parse
let tokens = @parser.tokenize(code)
// Type of let mut x = 10; x = 5;
let (s, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:])
let _ = ctx.check_let_mut_stmt(s)
let (a, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let a = ctx.check_assign_stmt(a)
assert_true(a.left_value.ty.kind is Int)
assert_true(a.expr.ty is Int)
// Type of `let arr = [1, 2, 3]; arr[0] = 10;`
let (s, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let _ = ctx.check_let_stmt(s)
let (a, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let a = ctx.check_assign_stmt(a)
assert_true(a.left_value.ty.kind is Int)
assert_true(a.left_value.ty.mutable is true)
assert_true(a.expr.ty is Int)
// Type of `x += true;` should fail
let (a, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let t = try? ctx.check_assign_stmt(a)
assert_true(t is Err(_))
// Type of `let mut a = xxx; a += 33;`
ctx.type_vars.set(0, TypeVar(0))
ctx.type_env.set("xxx", { kind: TypeVar(0), mutable: false })
let (s, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let _ = ctx.check_let_mut_stmt(s)
let (a, _, _) = @parser.parse_stmt_or_expr_end(tok_view)
let _ = ctx.check_assign_stmt(a)
assert_true(ctx.lookup_type("a") is Some({ kind: Int, .. }))
assert_true(ctx.lookup_type("xxx") is Some({ kind: Int, .. }))
}
///|
test "Block Expr TypeCheck Test" {
let code =
#|let arr = [];
#|
#|{
#| let (x, y) = (1, 2);
#| let mut z = x + y;
#| z = z + 10;
#| arr.push(z);
#|}
#|
#|{
#| let (x, y) = (1.0, 2.0);
#| let z = x * y;
#| let arr = [1.0, z, 5.0];
#| arr[2] = arr[1];
#| arr[2]
#|}
#|
#|{ arr[0] = 1.0; } // should cause type error
let ctx = Context::new()
// Parse
let tokens = @parser.tokenize(code)
let (let_stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:])
let _ = ctx.check_let_stmt(let_stmt)
// TypeCheck first block expr
let (block_expr1, tok_view) = @parser.parse_block_expr(tok_view)
let checked_block1 = ctx.check_block_expr(block_expr1)
assert_true(checked_block1.stmts.length() is 5) // Lilunar adds an implicit Expr(Unit) at the end
assert_true(checked_block1.ty is Unit)
// TypeCheck second block expr
let (block_expr2, tok_view) = @parser.parse_block_expr(tok_view)
let checked_block2 = ctx.check_block_expr(block_expr2)
assert_true(checked_block2.stmts.length() is 5)
assert_true(checked_block2.ty is Double)
// TypeCheck third block expr (should raise type error)
assert_true(ctx.lookup_type("arr") is Some({ kind: Array(Int), .. }))
let (block_expr3, _) = @parser.parse_block_expr(tok_view)
let checked_block3 = try? ctx.check_block_expr(block_expr3)
assert_true(checked_block3 is Err(_))
}
///|
test "Expr TypeCheck Test" {
let code =
#|arr[3] + arr[5] ; fact(5) + fib(10) ;
#|mat.data[3][4] ; sum(arr); 3.0 > max(x, y);
let ctx = Context::new()
ctx.struct_defs.set("Matrix", {
name: "Matrix",
fields: [{ name: "data", ty: { kind: Array(Array(Int)), mutable: false } }],
})
ctx.type_env.set("arr", { kind: Array(Int), mutable: false })
ctx.type_env.set("mat", { kind: Struct("Matrix"), mutable: false })
// Note: Matrix struct has a field `data` of type Array(Array(Int))
ctx.type_env.set("fact", { kind: Function([TypeVar(0)], Int), mutable: false })
ctx.type_env.set("fib", { kind: Function([Int], Int), mutable: false })
ctx.type_env.set("sum", { kind: Function([Array(Int)], Int), mutable: false })
// Note: max is local function without type annotations
ctx.type_env.set("max", {
kind: Function([TypeVar(1), TypeVar(2)], TypeVar(3)),
mutable: false,
})
ctx.type_env.set("x", { kind: Double, mutable: true })
ctx.type_env.set("y", { kind: Double, mutable: true })
ctx.type_vars.set(0, TypeVar(0))
ctx.type_vars.set(1, TypeVar(1))
ctx.type_vars.set(2, TypeVar(1))
ctx.type_vars.set(3, TypeVar(1))
// Parse
let tokens = @parser.tokenize(code)
// Type of arr[3] + arr[5]
let (a, tok_view) = @parser.parse_expr(tokens[:])
let a = ctx.check_expr(a)
assert_true(a.ty is Int)
// Type of fact(5) + fib(10)
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Int)
// Type of mat.data[3][4]
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Int)
// Type of sum(arr)
let (a, tok_view) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Int)
// Type of 3.0 > max(x, y)
let (a, _) = @parser.parse_expr(tok_view[1:])
let a = ctx.check_expr(a)
assert_true(a.ty is Bool)
assert_true(ctx.lookup_type("x") is Some({ kind: Double, .. }))
assert_true(ctx.lookup_type("y") is Some({ kind: Double, .. }))
}
///|
test "If Expr TypeCheck Test" {
let code =
#|let arr = [];
#|
#|if a > b {
#| arr.push(1);
#|} else {
#| arr.push(2);
#|}
#|
#|if a < b {
#| arr.push(a);
#| a
#|}else if a == b {
#| arr.push(1);
#| a - b
#|} else {
#| arr.push(b);
#| b
#|}
#|
let ctx = Context::new()
ctx.type_env.set("a", { kind: Int, mutable: false })
ctx.type_env.set("b", { kind: Int, mutable: false })
// Parse
let tokens = @parser.tokenize(code)
let (let_stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:])
let _ = ctx.check_let_stmt(let_stmt)
// TypeCheck first if expr
let (if_expr1, tok_view) = @parser.parse_if_expr(tok_view)
let checked_if1 = ctx.check_if_expr(if_expr1)
assert_true(checked_if1.ty is Unit)
assert_true(ctx.lookup_type("arr") is Some({ kind: Array(Int), .. }))
// TypeCheck second if expr
let (if_expr2, _) = @parser.parse_if_expr(tok_view)
let checked_if2 = ctx.check_if_expr(if_expr2)
assert_true(checked_if2.ty is Int)
assert_true(ctx.lookup_type("a") is Some({ kind: Int, .. }))
assert_true(ctx.lookup_type("b") is Some({ kind: Int, .. }))
}
///|
test "While Stmt TypeCheck Test" {
let code =
#|let mut i = 0;
#|while i < 10 {
#| i = i + 1;
#|}
let ctx = Context::new()
// Parse
let tokens = @parser.tokenize(code)
let (let_stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:])
let _ = ctx.check_let_mut_stmt(let_stmt)
let (while_stmt, _, _) = @parser.parse_stmt_or_expr_end(tok_view)
let _ = ctx.check_while_stmt(while_stmt)
}
///|
test "Local Function Type Check Test" {
let code =
#|fn max(a, b) {
#| if a > b { a } else { b }
#|}
#|
#|let _ = max(10, 20);
#|let _ = max(1.0, 2.0); // This should raise a type error
let ctx = Context::new()
// Parse
let tokens = @parser.tokenize(code)
guard @parser.parse_stmt_or_expr_end(tokens[:])
is (LocalFunction(local_func), false, tok_view)
// TypeCheck
let _ = ctx.check_local_function(local_func)
// parse and check let statement
let (let_stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view)
let let_stmt = ctx.check_let_stmt(let_stmt)
assert_true(let_stmt.ty is Int)
let (let_stmt2, _, _) = @parser.parse_stmt_or_expr_end(tok_view)
let let_stmt2 = try? ctx.check_let_stmt(let_stmt2)
assert_true(let_stmt2 is Err(_))
}
///|
test "Struct Construct TypeCheck Test" {
let code =
#|struct Point { x: Int; y: Int; }
#|fn main {
#| let p1 = Point::{ x: 1, y: 2 };
#| let p2 = Point::{ x: 1, y: true };
#|}
let tokens = @parser.tokenize(code)
let ctx = Context::new()
let program = @parser.parse_program(tokens)
let struct_def = program.struct_defs["Point"]
guard program.top_functions["main"].body is [let_stmt1, let_stmt2, ..]
// 1. Parse and check the struct definition to populate the context
ctx.struct_defs.set("Point", ctx.check_struct_def(struct_def)) // Assumes check_struct_def works
// 2. Parse and check the valid `let p1 = ...` statement
let _ = ctx.check_let_stmt(let_stmt1)
let p1_type = ctx.lookup_type("p1")
assert_true(p1_type is Some({ kind: Struct("Point"), .. }))
// 3. Parse and check the invalid `let p2 = ...` statement
let result = try? ctx.check_let_stmt(let_stmt2)
assert_true(result is Err(_))
}
///|
test "Top Function TypeCheck Test" {
let code =
#|fn fib(n : Int) -> Int {
#| if n <= 1 {
#| return n;
#| } else {
#| return fib(n - 1) + fib(n - 2);
#| };
#| 0
#|}
let ctx = Context::new()
ctx.func_types.set("fib", Function([Int], Int))
ctx.type_env.set("fib", { kind: Function([Int], Int), mutable: false })
// parse
let tokens = @parser.tokenize(code)
let program = @parser.parse_program(tokens)
let _ = ctx.check_top_function(program.top_functions["fib"])
}
///|
test "Program TypeCheck Test" {
let code =
#|let a = 3;
#|let b = 4;
#|fn fold(arr: Array[Int], f: (Int, Int) -> Int, init: Int) -> Int {
#| let mut result = init;
#| let mut i = 0;
#| while i < arr.length() {
#| result = f(result, arr[i]);
#| }
#| result
#|}
#|
#|fn main {
#| fn max(a, b) { if a > b { a } else { b } }
#| fn min(a, b) { if a < b { a } else { b } }
#| let numbers = [a, 1, b, 1, 5, 9, 2, 6, 5];
#| let maximum = fold(numbers, max, -1000);
#| let minimum = fold(numbers, min, 1000);
#| let max_min_diff = maximum - minimum;
#| print_int(max_min_diff);
#|}
let ctx = Context::new()
ctx.type_env.set("print_int", { kind: Function([Int], Unit), mutable: false })
ctx.func_types.set("print_int", Function([Int], Unit))
// Type check the program
let tokens = @parser.tokenize(code)
let program = @parser.parse_program(tokens)
let _ = ctx.check_program(program)
}
///|
test "TypeCheck Test" {
let code =
#|let a = 3;
#|let b = 4;
#|fn fold(arr: Array[Int], f: (Int, Int) -> Int, init: Int) -> Int {
#| let mut result = init;
#| let mut i = 0;
#| while i < arr.length() {
#| result = f(result, arr[i]);
#| }
#| result
#|}
#|
#|fn main {
#| fn max(a, b) { if a > b { a } else { b } }
#| fn min(a, b) { if a < b { a } else { b } }
#| let numbers = [a, 1, b, 1, 5, 9, 2, 6, 5];
#| let maximum = fold(numbers, max, -1000);
#| let minimum = fold(numbers, min, 1000);
#| let max_min_diff = maximum - minimum;
#|}
let tokens = @parser.tokenize(code)
let program = @parser.parse_program(tokens)
let program = typecheck(program)
let program_str = program.to_string()
assert_false(program_str.contains("TypeVar"))
}

View File

@@ -0,0 +1,241 @@
///|
pub fn typecheck(program : @parser.Program) -> Program raise TypeCheckError {
let ctx = Context::new()
let checked_program = ctx.check_program(program)
ctx.substitute_type_var(checked_program)
}
///|
pub fn Context::substitute_type_var(
self : Context,
program : Program,
) -> Program {
for name, top_let in program.top_lets {
program.top_lets.set(name, {
name,
expr: self.substitute_type_var_for_expr(top_let.expr),
ty: {
kind: self.deref_type_var(top_let.ty.kind),
mutable: top_let.ty.mutable,
},
})
}
for name, top_func in program.top_functions {
program.top_functions.set(name, {
fname: name,
param_list: top_func.param_list.map(param => Param::{
name: param.name,
ty: self.deref_type_var(param.ty),
}),
ret_ty: self.deref_type_var(top_func.ret_ty),
body: self.substitute_type_var_for_block_expr(top_func.body),
})
}
program
}
///|
pub fn Context::substitute_type_var_for_expr(
self : Context,
expr : Expr,
) -> Expr {
{
kind: match expr.kind {
Or(left, right) =>
Or(
self.substitute_type_var_for_expr(left),
self.substitute_type_var_for_expr(right),
)
And(left, right) =>
And(
self.substitute_type_var_for_expr(left),
self.substitute_type_var_for_expr(right),
)
MulDivRem(op, left, right) =>
MulDivRem(
op,
self.substitute_type_var_for_expr(left),
self.substitute_type_var_for_expr(right),
)
AddSub(op, left, right) =>
AddSub(
op,
self.substitute_type_var_for_expr(left),
self.substitute_type_var_for_expr(right),
)
Compare(op, left, right) =>
Compare(
op,
self.substitute_type_var_for_expr(left),
self.substitute_type_var_for_expr(right),
)
NegExpr(inner) => NegExpr(self.substitute_type_var_for_expr(inner))
NotExpr(inner) => NotExpr(self.substitute_type_var_for_expr(inner))
BlockExpr({ stmts, ty }) =>
BlockExpr({
stmts: stmts.map(stmt => {
kind: match stmt.kind {
ReturnStmt(expr) =>
ReturnStmt(self.substitute_type_var_for_expr(expr))
ExprStmt(expr) =>
ExprStmt(self.substitute_type_var_for_expr(expr))
WhileStmt({ cond, body }) =>
WhileStmt({
cond: self.substitute_type_var_for_expr(cond),
body: self.substitute_type_var_for_block_expr(body),
})
AssignStmt({ left_value, expr }) =>
AssignStmt({
left_value: self.substitute_type_var_for_left_value(
left_value,
),
expr: self.substitute_type_var_for_expr(expr),
})
LetStmt({ pattern, ty, expr }) =>
LetStmt({
pattern: self.substitute_type_var_for_let_pattern(pattern),
ty: self.deref_type_var(ty),
expr: self.substitute_type_var_for_expr(expr),
})
LetMutStmt({ name, ty, expr }) =>
LetMutStmt({
name,
ty: {
kind: self.deref_type_var(ty.kind),
mutable: ty.mutable,
},
expr: self.substitute_type_var_for_expr(expr),
})
LocalFunction({ fname, param_list, ret_ty, body }) =>
LocalFunction({
fname,
param_list: param_list.map(param => (
param.0,
{
kind: self.deref_type_var(param.1.kind),
mutable: param.1.mutable,
},
)),
ret_ty: {
kind: self.deref_type_var(ret_ty.kind),
mutable: ret_ty.mutable,
},
body: self.substitute_type_var_for_block_expr(body),
})
},
}),
ty: self.deref_type_var(ty),
})
ApplyExpr(apply) =>
ApplyExpr(self.substitute_type_var_for_apply_expr(apply))
IfExpr({ cond, then_block, else_block, ty }) =>
IfExpr({
cond: self.substitute_type_var_for_expr(cond),
then_block: self.substitute_type_var_for_block_expr(then_block),
else_block: else_block.map(eb => self.substitute_type_var_for_expr(eb)),
ty: self.deref_type_var(ty),
})
},
ty: self.deref_type_var(expr.ty),
}
}
///|
fn Context::substitute_type_var_for_block_expr(
self : Context,
block : BlockExpr,
) -> BlockExpr {
guard self.substitute_type_var_for_expr({ kind: BlockExpr(block), ty: Unit }).kind // only use kind
is BlockExpr(block)
block
}
///|
fn Context::substitute_type_var_for_left_value(
self : Context,
left_value : LeftValue,
) -> LeftValue {
{
kind: match left_value.kind {
Ident(name) => Ident(name)
ArrayAccess(base, index_expr) =>
ArrayAccess(
self.substitute_type_var_for_left_value(base),
self.substitute_type_var_for_expr(index_expr),
)
FieldAccess(base, field_name) =>
FieldAccess(self.substitute_type_var_for_left_value(base), field_name)
},
ty: {
kind: self.deref_type_var(left_value.ty.kind),
mutable: left_value.ty.mutable,
},
}
}
///|
fn Context::substitute_type_var_for_let_pattern(
self : Context,
pattern : Pattern,
) -> Pattern {
{
kind: match pattern.kind {
Ident(name) => Ident(name)
Wildcard => Wildcard
Tuple(patterns) =>
Tuple(patterns.map(p => self.substitute_type_var_for_let_pattern(p)))
},
}
}
///|
fn Context::substitute_type_var_for_apply_expr(
self : Context,
apply_expr : ApplyExpr,
) -> ApplyExpr {
{
kind: match apply_expr.kind {
AtomExpr({ kind, ty }) =>
AtomExpr({
kind: match kind {
Int(v) => Int(v)
Double(v) => Double(v)
Bool(v) => Bool(v)
Unit => Unit
Ident(name) => Ident(name)
Tuple(elems) =>
Tuple(elems.map(elem => self.substitute_type_var_for_expr(elem)))
Array(elems) =>
Array(elems.map(elem => self.substitute_type_var_for_expr(elem)))
ArrayMake(size_expr, init_expr) =>
ArrayMake(
self.substitute_type_var_for_expr(size_expr),
self.substitute_type_var_for_expr(init_expr),
)
StructConstruct({ name, fields }) =>
StructConstruct({
name,
fields: fields.map(field => {
let (field_name, field_expr) = field
(field_name, self.substitute_type_var_for_expr(field_expr))
}),
})
},
ty: self.deref_type_var(ty),
})
ArrayAccess(base, index_expr) =>
ArrayAccess(
self.substitute_type_var_for_apply_expr(base),
self.substitute_type_var_for_expr(index_expr),
)
FieldAccess(base, field_name) =>
FieldAccess(self.substitute_type_var_for_apply_expr(base), field_name)
Call(callee, args) =>
Call(
self.substitute_type_var_for_apply_expr(callee),
args.map(arg => self.substitute_type_var_for_expr(arg)),
)
},
ty: self.deref_type_var(apply_expr.ty),
}
}

163
src/typecheck/typedef.mbt Normal file
View File

@@ -0,0 +1,163 @@
///|
pub(all) struct Type {
kind : TypeKind
mutable : Bool
} derive(Show)
///|
pub(all) enum TypeKind {
Unit
Bool
Int
Double
Tuple(Array[TypeKind])
Array(TypeKind)
Function(Array[TypeKind], TypeKind)
Struct(String)
Any
TypeVar(Int)
} derive(Eq, Hash)
///|
pub fn Context::check_parser_type(
self : Context,
ty : @parser.Type,
mutable? : Bool = false,
) -> Type raise TypeCheckError {
match ty {
UserDefined(name) => {
if self.type_env.get("$Generic$\{name}") is Some(t) {
return t
}
if !self.struct_defs.contains(name) {
raise TypeCheckError("Undefined type: \{name}")
}
{ kind: Struct(name), mutable }
}
Function(param_types, return_type) => {
let param_types = param_types.map(param_type => self.check_parser_type(
param_type,
).kind)
let return_type = self.check_parser_type(return_type).kind
{ kind: Function(param_types, return_type), mutable }
}
Tuple(types) =>
{
kind: Tuple(types.map(type_ => self.check_parser_type(type_).kind)),
mutable,
}
Array(elem_type) =>
{ kind: Array(self.check_parser_type(elem_type).kind), mutable }
Double => { kind: Double, mutable }
Int => { kind: Int, mutable }
Bool => { kind: Bool, mutable }
Unit => { kind: Unit, mutable }
Generic(_) => ...
}
}
///|
pub fn Context::deref_type_var(self : Context, ty : TypeKind) -> TypeKind {
match ty {
TypeVar(r) if self.type_vars[r] == TypeVar(r) => ty
TypeVar(r) => self.deref_type_var(self.type_vars[r]) // 要求不能有环
Tuple(elem_types) =>
Tuple(elem_types.map(typekind => self.deref_type_var(typekind)))
Array(elem_type) => Array(self.deref_type_var(elem_type))
Function(param_types, return_type) =>
Function(
param_types.map(typekind => self.deref_type_var(typekind)),
self.deref_type_var(return_type),
)
_ => ty
}
}
///|
pub fn Context::is_type_compatible(
self : Context,
a : TypeKind,
b : TypeKind,
) -> Bool {
fn set_type_vars(self : Context, id : Int, target : TypeKind) -> Unit {
if target == TypeVar(id) {
return
}
for var_id, var_type in self.type_vars {
if var_type == TypeVar(id) {
self.type_vars[var_id] = target
set_type_vars(self, var_id, target)
}
}
}
let deref_a = self.deref_type_var(a)
let deref_b = self.deref_type_var(b)
match (deref_a, deref_b) {
(TypeVar(id_a), TypeVar(id_b)) => {
if id_a != id_b {
set_type_vars(self, id_a, TypeVar(id_b))
}
true
}
(TypeVar(id), other) | (other, TypeVar(id)) => {
set_type_vars(self, id, other)
true
}
(Unit, Unit) => true
(Bool, Bool) => true
(Int, Int) => true
(Double, Double) => true
(Tuple(types_a), Tuple(types_b)) => {
if types_a.length() != types_b.length() {
return false
}
let mut result = true
for i in 0..<types_a.length() {
if !self.is_type_compatible(types_a[i], types_b[i]) {
result = false
}
}
result
}
(Array(elem_a), Array(elem_b)) => self.is_type_compatible(elem_a, elem_b)
(Function(params_a, ret_a), Function(params_b, ret_b)) => {
if params_a.length() != params_b.length() {
return false
}
let mut result = true
for i in 0..<params_a.length() {
if !self.is_type_compatible(params_a[i], params_b[i]) {
result = false
}
}
self.is_type_compatible(ret_a, ret_b) && result
}
(Struct(name_a), Struct(name_b)) => name_a == name_b
(Any, _) | (_, Any) => true
_ => false
}
}
///|
pub impl Show for TypeKind with output(self, logger) {
let s = match self {
Unit => "Unit"
Bool => "Bool"
Int => "Int"
Double => "Double"
Tuple(types) => {
let inner = types.map(typekind => typekind.to_string()).join(", ")
"(\{inner})"
}
Array(elem_type) => "Array[\{elem_type.to_string()}]"
Function(param_types, return_type) => {
let params = param_types.map(typekind => typekind.to_string()).join(", ")
"(\{params}) -> \{return_type.to_string()}"
}
Struct(name) => "\{name}"
Any => "Any"
TypeVar(name) => "TypeVar(\{name})"
}
logger.write_string(s)
}