From 7c9ba92769fa788dc3bbcc974c3a1965a1bd862f Mon Sep 17 00:00:00 2001 From: Lil-Ran Date: Tue, 4 Nov 2025 02:37:01 +0800 Subject: [PATCH] feat: type check --- src/parser/ast.mbt | 234 +++++--- src/parser/ast_wbtest.mbt | 2 +- src/parser/tokenize.mbt | 2 +- src/typecheck/expr.mbt | 177 ++++++ src/typecheck/expr_apply.mbt | 95 +++ src/typecheck/expr_atom.mbt | 106 ++++ src/typecheck/expr_block.mbt | 23 + src/typecheck/expr_if.mbt | 45 ++ src/typecheck/moon.pkg.json | 5 + src/typecheck/program.mbt | 153 +++++ src/typecheck/stmt.mbt | 43 ++ src/typecheck/stmt_assign.mbt | 84 +++ src/typecheck/stmt_let.mbt | 77 +++ src/typecheck/stmt_let_mut.mbt | 29 + src/typecheck/stmt_local_function.mbt | 52 ++ src/typecheck/stmt_while.mbt | 23 + src/typecheck/struct_def.mbt | 36 ++ src/typecheck/top_function.mbt | 63 ++ src/typecheck/top_let.mbt | 27 + src/typecheck/typecheck_test.mbt | 826 ++++++++++++++++++++++++++ src/typecheck/typechecker.mbt | 241 ++++++++ src/typecheck/typedef.mbt | 163 +++++ 22 files changed, 2418 insertions(+), 88 deletions(-) create mode 100644 src/typecheck/expr.mbt create mode 100644 src/typecheck/expr_apply.mbt create mode 100644 src/typecheck/expr_atom.mbt create mode 100644 src/typecheck/expr_block.mbt create mode 100644 src/typecheck/expr_if.mbt create mode 100644 src/typecheck/moon.pkg.json create mode 100644 src/typecheck/program.mbt create mode 100644 src/typecheck/stmt.mbt create mode 100644 src/typecheck/stmt_assign.mbt create mode 100644 src/typecheck/stmt_let.mbt create mode 100644 src/typecheck/stmt_let_mut.mbt create mode 100644 src/typecheck/stmt_local_function.mbt create mode 100644 src/typecheck/stmt_while.mbt create mode 100644 src/typecheck/struct_def.mbt create mode 100644 src/typecheck/top_function.mbt create mode 100644 src/typecheck/top_let.mbt create mode 100644 src/typecheck/typecheck_test.mbt create mode 100644 src/typecheck/typechecker.mbt create mode 100644 src/typecheck/typedef.mbt diff --git a/src/parser/ast.mbt 
b/src/parser/ast.mbt index 65db1dd..e0ffba5 100644 --- a/src/parser/ast.mbt +++ b/src/parser/ast.mbt @@ -15,7 +15,7 @@ pub(all) enum Type { } derive(Show) ///| -enum Literal { +pub(all) enum Literal { Unit Bool(Bool) Int(Int) @@ -23,20 +23,20 @@ enum Literal { } derive(Show) ///| -enum AddSubOp { +pub(all) enum AddSubOp { Add Sub } derive(Show) ///| -enum MulDivRemOp { +pub(all) enum MulDivRemOp { Mul Div Rem } derive(Show) ///| -enum Pattern { +pub(all) enum Pattern { Wildcard Identifier(String) Literal(Literal) @@ -45,7 +45,7 @@ enum Pattern { } derive(Show) ///| -enum Expr { +pub(all) enum Expr { Or(Expr, Expr) And(Expr, Expr) Compare(CompareOperator, Expr, Expr) @@ -69,7 +69,7 @@ enum Expr { } derive(Show) ///| -struct Function { +pub(all) struct Function { id : String user_defined_type : Type? params : Array[(String, Type?)] @@ -78,15 +78,15 @@ struct Function { } derive(Show) ///| -enum Binding { +pub(all) enum Binding { Identifier(String) Wildcard + Tuple(Array[Binding]) } derive(Show) ///| -enum Stmt { +pub(all) enum Stmt { Let(Binding, Type?, Expr) - LetTuple(Array[Binding], Type?, Expr) LetMut(String, Type?, Expr) Assign(Expr, Expr) While(Expr, Array[Stmt]) @@ -96,28 +96,28 @@ enum Stmt { } derive(Show) ///| -struct TopLet { +pub(all) struct TopLet { id : String type_ : Type? expr : Expr } derive(Show) ///| -struct StructDef { +pub(all) struct StructDef { id : String user_defined_type : Type? fields : Array[(String, Type)] } derive(Show) ///| -struct EnumDef { +pub(all) struct EnumDef { id : String user_defined_type : Type? 
variants : Array[(String, Array[Type])] } derive(Show) ///| -struct Program { +pub(all) struct Program { top_lets : Map[String, TopLet] top_functions : Map[String, Function] struct_defs : Map[String, StructDef] @@ -125,7 +125,7 @@ struct Program { } derive(Show) ///| -fn parse_type( +pub fn parse_type( tokens : ArrayView[Token], ) -> (Type, ArrayView[Token]) raise ParseError { match tokens { @@ -136,7 +136,7 @@ fn parse_type( [Array, LBracket, .. rest] => { let (elem_type, rest) = parse_type(rest) guard rest is [RBracket, .. rest] else { - raise ParseError("Expected ']' after array type") + raise ParseError("Expected ']' after array type, got \{rest[0]}") } (Array(elem_type), rest) } @@ -156,7 +156,7 @@ fn parse_type( [RParen, ..] => continue r _ => raise ParseError( - "Expected ',' or ')' or ')' '->' in tuple/function type", + "Expected ',' or ')' or ')' '->' in tuple/function type, got \{r[0]}", ) } continue r @@ -166,21 +166,26 @@ fn parse_type( [UpperIdentifier(name), LBracket, .. rest] => { let (type_arg, rest) = parse_type(rest) guard rest is [RBracket, .. rest] else { - raise ParseError("Expected ']' after generic type argument") + raise ParseError( + "Expected ']' after generic type argument, got \{rest[0]}", + ) } (Generic(name, type_arg), rest) } [UpperIdentifier(name), .. rest] => (UserDefined(name), rest) - _ => raise ParseError("Unexpected token while parsing type") + tokens => + raise ParseError("Unexpected token while parsing type: \{tokens[0]}") } } ///| -fn parse_struct_decl( +pub fn parse_struct_decl( tokens : ArrayView[Token], ) -> (StructDef, ArrayView[Token]) raise ParseError { guard tokens is [Struct, UpperIdentifier(id), .. rest] else { - raise ParseError("Expected upper case struct name after 'struct'") + raise ParseError( + "Expected upper case struct name after 'struct', got \{tokens[1]}", + ) } let (user_defined_type, rest) = if rest is [LBracket, UpperIdentifier(type_), RBracket, .. 
r] { @@ -189,7 +194,7 @@ fn parse_struct_decl( (None, rest) } guard rest is [LCurlyBracket, .. rest] else { - raise ParseError("Expected '{' after struct name") + raise ParseError("Expected '{' after struct name, got \{rest[0]}") } let fields = [] loop rest { @@ -200,19 +205,25 @@ fn parse_struct_decl( match r { [Semicolon, .. r] => continue r [RCurlyBracket, ..] => continue r - _ => raise ParseError("Expected ';' or '}' after struct field") + _ => + raise ParseError( + "Expected ';' or '}' after struct field, got \{r[0]}", + ) } } - _ => raise ParseError("Unexpected token in struct field list") + rest => + raise ParseError("Unexpected token in struct field list, got \{rest[0]}") } } ///| -fn parse_enum_decl( +pub fn parse_enum_decl( tokens : ArrayView[Token], ) -> (EnumDef, ArrayView[Token]) raise ParseError { guard tokens is [Enum, UpperIdentifier(id), .. rest] else { - raise ParseError("Expected upper case enum name after 'enum'") + raise ParseError( + "Expected upper case enum name after 'enum', got \{tokens[1]}", + ) } let (user_defined_type, rest) = if rest is [LBracket, UpperIdentifier(type_), RBracket, .. r] { @@ -221,7 +232,7 @@ fn parse_enum_decl( (None, rest) } guard rest is [LCurlyBracket, .. rest] else { - raise ParseError("Expected '{' after enum name") + raise ParseError("Expected '{' after enum name, got \{rest[0]}") } let variants = [] loop rest { @@ -239,7 +250,7 @@ fn parse_enum_decl( [RParen, ..] => continue r _ => raise ParseError( - "Expected ',' or ')' in enum variant type list", + "Expected ',' or ')' in enum variant type list, got \{r[0]}", ) } } @@ -251,15 +262,19 @@ fn parse_enum_decl( match r { [Semicolon, .. r] => continue r [RCurlyBracket, ..] 
=> continue r - _ => raise ParseError("Expected ';' or '}' after enum variant") + _ => + raise ParseError( + "Expected ';' or '}' after enum variant, got \{r[0]}", + ) } } - _ => raise ParseError("Unexpected token in enum variant list") + rest => + raise ParseError("Unexpected token in enum variant list, got \{rest[0]}") } } ///| -fn parse_optional_type_annotation( +pub fn parse_optional_type_annotation( tokens : ArrayView[Token], ) -> (Type?, ArrayView[Token]) raise ParseError { if tokens is [Colon, .. r] { @@ -271,18 +286,18 @@ fn parse_optional_type_annotation( } ///| -fn parse_let_stmt_type_expr( +pub fn parse_let_stmt_type_expr( tokens : ArrayView[Token], ) -> (Type?, Expr, ArrayView[Token]) raise ParseError { let (type_, rest) = parse_optional_type_annotation(tokens) guard rest is [Assign, .. rest] else { raise ParseError( - "Expected '=' after identifier or type annotation in let declaration", + "Expected '=' after identifier or type annotation in let declaration, got \{rest[0]}", ) } let (expr, rest) = parse_expr(rest) guard rest is [Semicolon, .. rest] else { - raise ParseError("Expected ';' after let statement") + raise ParseError("Expected ';' after let statement, got \{rest[0]}") } (type_, expr, rest) } @@ -291,7 +306,7 @@ fn parse_let_stmt_type_expr( /// Returns `(Stmt, is_end_expr, rest_tokens)`. /// Semicolon is consumed for statements. /// End expr is wrapped in `Stmt::Expr`. -fn parse_stmt_or_expr_end( +pub fn parse_stmt_or_expr_end( tokens : ArrayView[Token], ) -> (Stmt, Bool, ArrayView[Token]) raise ParseError { match tokens { @@ -306,19 +321,25 @@ fn parse_stmt_or_expr_end( match binding { LowerIdentifier(name) | UpperIdentifier(name) => Identifier(name) Wildcard => Wildcard - _ => raise ParseError("Unreachable") + _ => panic() }, ) match r { [Comma, .. r] => continue r [RParen, ..] 
=> continue r - _ => raise ParseError("Expected ',' or ')' after let binding") + _ => + raise ParseError( + "Expected ',' or ')' after let binding, got \{r[0]}", + ) } } - _ => raise ParseError("Unexpected token in let tuple binding list") + rest => + raise ParseError( + "Unexpected token in let tuple binding list, got \{rest[0]}", + ) } let (type_, expr, rest) = parse_let_stmt_type_expr(rest) - (LetTuple(bindings, type_, expr), false, rest) + (Let(Tuple(bindings), type_, expr), false, rest) } // let_mut_stmt @@ -332,7 +353,7 @@ fn parse_stmt_or_expr_end( let binding : Binding = match b { LowerIdentifier(name) | UpperIdentifier(name) => Identifier(name) Wildcard => Wildcard - _ => raise ParseError("Unreachable") + _ => panic() } let (type_, expr, rest) = parse_let_stmt_type_expr(rest) (Let(binding, type_, expr), false, rest) @@ -350,10 +371,15 @@ fn parse_stmt_or_expr_end( [Comma, .. r] => continue r [RParen, ..] => continue r _ => - raise ParseError("Expected ',' or ')' after function parameter") + raise ParseError( + "Expected ',' or ')' after function parameter, got \{r[0]}", + ) } } - _ => raise ParseError("Unexpected token in function parameter list") + rest => + raise ParseError( + "Unexpected token in function parameter list, got \{rest[0]}", + ) } let (return_type, rest) = if rest is [Arrow, .. r] { let (t, r) = parse_type(r) @@ -374,7 +400,7 @@ fn parse_stmt_or_expr_end( [Return, .. rest] => { let (expr, rest) = parse_expr(rest) guard rest is [Semicolon, .. rest] else { - raise ParseError("Expected ';' after return statement") + raise ParseError("Expected ';' after return statement, got \{rest[0]}") } (Return(expr), false, rest) } @@ -383,13 +409,19 @@ fn parse_stmt_or_expr_end( [While, .. rest] => { let (cond_expr, rest) = parse_expr(rest) let stmts = [] + guard rest is [LCurlyBracket, .. rest] else { + raise ParseError("Expected '{' after while condition, got \{rest[0]}") + } loop rest { - [RCurlyBracket, .. 
r] => (While(cond_expr, stmts), false, r) + [RCurlyBracket, .. r] => { + stmts.push(Expr(Literal(Unit))) + (While(cond_expr, stmts), false, r) + } r => { let (stmt, is_end_expr, rest) = parse_stmt_or_expr_end(r) if is_end_expr { raise ParseError( - "Unexpected end expression in while statement body", + "Unexpected end expression \{stmt} in while statement body", ) } stmts.push(stmt) @@ -410,28 +442,35 @@ fn parse_stmt_or_expr_end( _ => false } guard valid_left_value else { - raise ParseError("Invalid left-hand side in assignment") + raise ParseError("Invalid left-hand side in assignment: \{expr}") } let (rhs_expr, r) = parse_expr(r) guard r is [Semicolon, .. r] else { - raise ParseError("Expected ';' after assignment statement") + raise ParseError( + "Expected ';' after assignment statement, got \{r[0]}", + ) } (Assign(expr, rhs_expr), false, r) } [Semicolon, .. r] => (Expr(expr), false, r) [RCurlyBracket, ..] => (Expr(expr), true, rest) - _ => raise ParseError("Expected ';' or '}' after expression statement") + rest => + raise ParseError( + "Expected ';' or '}' after expression statement, got \{rest[0]}", + ) } } } } ///| -fn parse_block_expr( +pub fn parse_block_expr( tokens : ArrayView[Token], ) -> (Expr, ArrayView[Token]) raise ParseError { guard tokens is [LCurlyBracket, .. rest] else { - raise ParseError("Expected '{' at start of block expression") + raise ParseError( + "Expected '{' at start of block expression, got \{tokens[0]}", + ) } let stmts = [] let rest = loop rest { @@ -440,7 +479,9 @@ fn parse_block_expr( stmts.push(stmt) if is_end_expr { guard r is [RCurlyBracket, .. r] else { - raise ParseError("Expected '}' at end of block expression") + raise ParseError( + "Expected '}' at end of block expression, got \{r[0]}", + ) } break r } else if r is [RCurlyBracket, .. 
r] { @@ -454,11 +495,13 @@ fn parse_block_expr( } ///| -fn parse_if_expr( +pub fn parse_if_expr( tokens : ArrayView[Token], ) -> (Expr, ArrayView[Token]) raise ParseError { guard tokens is [If, .. rest] else { - raise ParseError("Expected 'if' at start of if expression") + raise ParseError( + "Expected 'if' at start of if expression, got \{tokens[0]}", + ) } let (cond, rest) = parse_expr(rest) let (then_branch, rest) = parse_block_expr(rest) @@ -477,7 +520,7 @@ fn parse_if_expr( } ///| -fn parse_value_level_expr( +pub fn parse_value_level_expr( tokens : ArrayView[Token], ) -> (Expr, ArrayView[Token]) raise ParseError { match tokens { @@ -486,11 +529,15 @@ fn parse_value_level_expr( [Array, DoubleColon, LowerIdentifier("make"), LParen, .. rest] => { let (size_expr, rest) = parse_expr(rest) guard rest is [Comma, .. rest] else { - raise ParseError("Expected ',' after size expression in array make") + raise ParseError( + "Expected ',' after size expression in array make, got \{rest[0]}", + ) } let (init_expr, rest) = parse_expr(rest) guard rest is [RParen, .. rest] else { - raise ParseError("Expected ')' after init expression in array make") + raise ParseError( + "Expected ')' after init expression in array make, got \{rest[0]}", + ) } (ArrayMake(size_expr, init_expr), rest) } @@ -506,10 +553,16 @@ fn parse_value_level_expr( match r { [Comma, .. r] => continue r [RCurlyBracket, ..] 
=> continue r - _ => raise ParseError("Expected ',' or '}' after struct field") + _ => + raise ParseError( + "Expected ',' or '}' after struct field, got \{r[0]}", + ) } } - _ => raise ParseError("Unexpected token in struct construction") + rest => + raise ParseError( + "Unexpected token in struct construction, got \{rest[0]}", + ) } } @@ -547,8 +600,10 @@ fn parse_value_level_expr( } else { (Tuple(exprs), r) } - _ => - raise ParseError("Expected ',' or ')' in tuple or grouped expression") + r => + raise ParseError( + "Expected ',' or ')' in tuple or grouped expression, got \{r[0]}", + ) } } @@ -562,10 +617,10 @@ fn parse_value_level_expr( elements.push(element) match r { [Comma, .. r] => continue r - [RBracket, ..] => (Array(elements), r) + [RBracket, ..] => continue r _ => raise ParseError( - "Expected ',' or ']' in array literal expression", + "Expected ',' or ']' in array literal expression, got \{r[0]}", ) } } @@ -578,12 +633,13 @@ fn parse_value_level_expr( // identifier_expr [LowerIdentifier(name) | UpperIdentifier(name), .. rest] => (Identifier(name), rest) - _ => raise ParseError("Unsupported value level expression") + tokens => + raise ParseError("Unsupported value level expression: \{tokens[:3]}") } } ///| -fn parse_get_or_apply_level_expr( +pub fn parse_get_or_apply_level_expr( tokens : ArrayView[Token], ) -> (Expr, ArrayView[Token]) raise ParseError { let (value, rest) = parse_value_level_expr(tokens) @@ -592,7 +648,7 @@ fn parse_get_or_apply_level_expr( [LBracket, .. r] => { let (index, r) = parse_expr(r) guard r is [RBracket, .. r] else { - raise ParseError("Expected ']' after index expression") + raise ParseError("Expected ']' after index expression, got \{r[0]}") } result = IndexAccess(result, index) continue r @@ -606,8 +662,11 @@ fn parse_get_or_apply_level_expr( args.push(arg) match r { [Comma, .. r] => continue r - [RParen, ..] => break r - _ => raise ParseError("Expected ',' or ')' in function call args") + [RParen, ..] 
=> continue r + _ => + raise ParseError( + "Expected ',' or ')' in function call args, got \{r[0]}", + ) } } } @@ -623,7 +682,7 @@ fn parse_get_or_apply_level_expr( } ///| -fn parse_if_level_expr( +pub fn parse_if_level_expr( tokens : ArrayView[Token], ) -> (Expr, ArrayView[Token]) raise ParseError { match tokens { @@ -634,7 +693,7 @@ fn parse_if_level_expr( } ///| -fn parse_mul_div_level_expr( +pub fn parse_mul_div_level_expr( tokens : ArrayView[Token], ) -> (Expr, ArrayView[Token]) raise ParseError { let (first, rest) = parse_if_level_expr(tokens) @@ -660,7 +719,7 @@ fn parse_mul_div_level_expr( } ///| -fn parse_add_sub_level_expr( +pub fn parse_add_sub_level_expr( tokens : ArrayView[Token], ) -> (Expr, ArrayView[Token]) raise ParseError { let (first, rest) = parse_mul_div_level_expr(tokens) @@ -681,7 +740,7 @@ fn parse_add_sub_level_expr( } ///| -fn parse_compare_level_expr( +pub fn parse_compare_level_expr( tokens : ArrayView[Token], ) -> (Expr, ArrayView[Token]) raise ParseError { let (first, rest) = parse_add_sub_level_expr(tokens) @@ -695,7 +754,7 @@ fn parse_compare_level_expr( } ///| -fn parse_and_level_expr( +pub fn parse_and_level_expr( tokens : ArrayView[Token], ) -> (Expr, ArrayView[Token]) raise ParseError { let (first, rest) = parse_compare_level_expr(tokens) @@ -711,7 +770,8 @@ fn parse_and_level_expr( } ///| -fn parse_or_level_expr( +#alias(parse_expr) +pub fn parse_or_level_expr( tokens : ArrayView[Token], ) -> (Expr, ArrayView[Token]) raise ParseError { let (first, rest) = parse_and_level_expr(tokens) @@ -726,13 +786,6 @@ fn parse_or_level_expr( } } -///| -fn parse_expr( - tokens : ArrayView[Token], -) -> (Expr, ArrayView[Token]) raise ParseError { - parse_or_level_expr(tokens) -} - ///| pub fn parse_program(tokens : Array[Token]) -> Program raise ParseError { let top_lets = Map::new() @@ -745,13 +798,15 @@ pub fn parse_program(tokens : Array[Token]) -> Program raise ParseError { let (type_, rest) = parse_optional_type_annotation(rest) guard rest 
is [Assign, .. rest] else { raise ParseError( - "Expected '=' after identifier or type annotation in let declaration", + "Expected '=' after identifier or type annotation in let declaration, got \{rest[0]}", ) } let (expr, rest) = parse_expr(rest) top_lets[id] = { id, type_, expr } guard rest is [Semicolon, .. rest] else { - raise ParseError("Expected ';' after top let declaration") + raise ParseError( + "Expected ';' after top let declaration, got \{rest[0]}", + ) } continue rest } @@ -782,9 +837,9 @@ pub fn parse_program(tokens : Array[Token]) -> Program raise ParseError { } [UpperIdentifier(id) | LowerIdentifier(id), LParen, .. r] => (None, id, r) - _ => + rest => raise ParseError( - "Expected function name (with optional type) after 'fn'", + "Expected function name (with optional type) after 'fn', got \{rest[0]}", ) } let params = [] @@ -797,13 +852,20 @@ pub fn parse_program(tokens : Array[Token]) -> Program raise ParseError { [Comma, .. r] => continue r [RParen, ..] => continue r _ => - raise ParseError("Expected ',' or ')' after function parameter") + raise ParseError( + "Expected ',' or ')' after function parameter, got \{r[0]}", + ) } } - _ => raise ParseError("Unexpected token in function parameter list") + rest => + raise ParseError( + "Unexpected token in function parameter list, got \{rest[0]}", + ) } guard rest is [Arrow, .. 
rest] else { - raise ParseError("Expected '->' after function parameter list") + raise ParseError( + "Expected '->' after function parameter list, got \{rest[0]}", + ) } let (return_type, rest) = parse_type(rest) guard parse_block_expr(rest) is (Block(body), rest) diff --git a/src/parser/ast_wbtest.mbt b/src/parser/ast_wbtest.mbt index 11537e2..f5c61aa 100644 --- a/src/parser/ast_wbtest.mbt +++ b/src/parser/ast_wbtest.mbt @@ -209,7 +209,7 @@ test "parse_get_or_apply_level_expr" { EOF, ]), content=( - #|(FunctionCall(Identifier("f"), [Literal(Int(1))]), [RParen, EOF]) + #|(FunctionCall(Identifier("f"), [Literal(Int(1))]), [EOF]) ), ) } diff --git a/src/parser/tokenize.mbt b/src/parser/tokenize.mbt index c819bb0..df581ba 100644 --- a/src/parser/tokenize.mbt +++ b/src/parser/tokenize.mbt @@ -257,7 +257,7 @@ pub fn tokenize(input : String) -> Array[Token] raise TokenizeError { continue rest } [] => tokens.push(EOF) - [c, ..] => raise TokenizeError("Unexpected character: \{c}") + str => raise TokenizeError("Unexpected token: \{try? 
str[:20]} ...") } tokens } diff --git a/src/typecheck/expr.mbt b/src/typecheck/expr.mbt new file mode 100644 index 0000000..06b59b0 --- /dev/null +++ b/src/typecheck/expr.mbt @@ -0,0 +1,177 @@ +///| +pub(all) struct Expr { + kind : ExprKind + ty : TypeKind +} derive(Show) + +///| +pub(all) enum ExprKind { + ApplyExpr(ApplyExpr) + BlockExpr(BlockExpr) + NotExpr(Expr) + NegExpr(Expr) + Compare(CompareOperator, Expr, Expr) + AddSub(AddSubOp, Expr, Expr) + MulDivRem(MulDivRemOp, Expr, Expr) + And(Expr, Expr) + Or(Expr, Expr) + IfExpr(IfExpr) +} derive(Show) + +///| +pub(all) enum CompareOperator { + Equal + NotEqual + GreaterEqual + LessEqual + Greater + Less +} derive(Show) + +///| +pub(all) enum AddSubOp { + Add + Sub +} derive(Show) + +///| +pub(all) enum MulDivRemOp { + Mul + Div + Rem +} derive(Show) + +///| +pub fn Context::check_expr( + self : Self, + expr : @parser.Expr, +) -> Expr raise TypeCheckError { + match expr { + Literal(_) + | Identifier(_) + | Tuple(_) + | Array(_) + | ArrayMake(_) + | StructConstruct(_) + | IndexAccess(_) + | FieldAccess(_) + | FunctionCall(_) => { + let apply_expr = self.check_apply_expr(expr) + { kind: ApplyExpr(apply_expr), ty: apply_expr.ty } + } + Not(inner_expr) => { + let inner_expr = self.check_expr(inner_expr) + guard inner_expr.ty is Bool else { + raise TypeCheckError("Operand of '!' 
must be Bool.") + } + { kind: NotExpr(inner_expr), ty: Bool } + } + Neg(inner_expr) => { + let inner_expr = self.check_expr(inner_expr) + match inner_expr.ty { + Int => { kind: NegExpr(inner_expr), ty: Int } + Double => { kind: NegExpr(inner_expr), ty: Double } + _ => raise TypeCheckError("Operand of unary '-' must be Int or Double.") + } + } + Compare(op, left, right) => { + let left = self.check_expr(left) + let right = self.check_expr(right) + if !self.is_type_compatible(left.ty, right.ty) { + raise TypeCheckError( + "Operands of comparison must be of compatible types.", + ) + } + guard self.deref_type_var(left.ty) + is (Int | Double | Bool | Unit | TypeVar(_)) else { + // XXX: Constraint TypeVar to these types + raise TypeCheckError( + "Operands of comparison must be Int, Double, Bool, or Unit.", + ) + } + { kind: Compare(compare_op_map(op), left, right), ty: Bool } + } + AddSub(op, left, right) => { + let left = self.check_expr(left) + let right = self.check_expr(right) + if !self.is_type_compatible(left.ty, right.ty) { + raise TypeCheckError("Operands of '+'/'-' must be of compatible types.") + } + match self.deref_type_var(left.ty) { + Int => { kind: AddSub(add_sub_op_map(op), left, right), ty: Int } + Double => { kind: AddSub(add_sub_op_map(op), left, right), ty: Double } + _ => raise TypeCheckError("Operands of '+'/'-' must be Int or Double.") + } + } + MulDivRem(op, left, right) => { + let left = self.check_expr(left) + let right = self.check_expr(right) + if !self.is_type_compatible(left.ty, right.ty) { + raise TypeCheckError( + "Operands of '*','/','%' must be of compatible types.", + ) + } + match self.deref_type_var(left.ty) { + Int => { kind: MulDivRem(mul_div_rem_op_map(op), left, right), ty: Int } + Double => + { kind: MulDivRem(mul_div_rem_op_map(op), left, right), ty: Double } + _ => + raise TypeCheckError("Operands of '*','/','%' must be Int or Double.") + } + } + And(left, right) => { + let left = self.check_expr(left) + let right = 
self.check_expr(right) + guard left.ty is Bool && right.ty is Bool else { + raise TypeCheckError("Operands of '&&' must be Bool.") + } + { kind: And(left, right), ty: Bool } + } + Or(left, right) => { + let left = self.check_expr(left) + let right = self.check_expr(right) + guard left.ty is Bool && right.ty is Bool else { + raise TypeCheckError("Operands of '||' must be Bool.") + } + { kind: Or(left, right), ty: Bool } + } + If(_) => { + let if_expr = self.check_if_expr(expr) + { kind: IfExpr(if_expr), ty: if_expr.ty } + } + Block(_) => { + let block_expr = self.check_block_expr(expr) + { kind: BlockExpr(block_expr), ty: block_expr.ty } + } + EnumConstruct(_) | Match(_) => ... + } +} + +///| +fn compare_op_map(op : @parser.CompareOperator) -> CompareOperator { + match op { + Equal => Equal + NotEqual => NotEqual + GreaterEqual => GreaterEqual + LessEqual => LessEqual + Greater => Greater + Less => Less + } +} + +///| +fn add_sub_op_map(op : @parser.AddSubOp) -> AddSubOp { + match op { + Add => Add + Sub => Sub + } +} + +///| +fn mul_div_rem_op_map(op : @parser.MulDivRemOp) -> MulDivRemOp { + match op { + Mul => Mul + Div => Div + Rem => Rem + } +} diff --git a/src/typecheck/expr_apply.mbt b/src/typecheck/expr_apply.mbt new file mode 100644 index 0000000..fe55190 --- /dev/null +++ b/src/typecheck/expr_apply.mbt @@ -0,0 +1,95 @@ +///| +pub(all) struct ApplyExpr { + kind : ApplyExprKind + ty : TypeKind +} derive(Show) + +///| +pub(all) enum ApplyExprKind { + AtomExpr(AtomExpr) + ArrayAccess(ApplyExpr, Expr) + FieldAccess(ApplyExpr, String) + Call(ApplyExpr, Array[Expr]) +} derive(Show) + +///| +pub fn Context::check_apply_expr( + self : Self, + apply_expr : @parser.Expr, +) -> ApplyExpr raise TypeCheckError { + match apply_expr { + Literal(_) + | Identifier(_) + | Tuple(_) + | Array(_) + | ArrayMake(_) + | StructConstruct(_) => { + let atom_expr = self.check_atom_expr(apply_expr) + { kind: AtomExpr(atom_expr), ty: atom_expr.ty } + } + IndexAccess(base_expr, 
index_expr) => { + let base_typed = self.check_apply_expr(base_expr) + guard base_typed.ty is Array(elem_type) else { + raise TypeCheckError("Attempting to index a non-array type.") + } + let index_typed = self.check_expr(index_expr) + guard index_typed.ty is Int else { + raise TypeCheckError("Array index must be of type Int.") + } + { kind: ArrayAccess(base_typed, index_typed), ty: elem_type } + } + FieldAccess(base_expr, field_name) => { + let base_typed = self.check_apply_expr(base_expr) + match base_typed.ty { + Array(t) => + { + kind: FieldAccess(base_typed, field_name), + ty: match field_name { + "length" => Function([], Int) + "push" => Function([t], Unit) + _ => raise TypeCheckError("Array has no method '\{field_name}'.") + }, + } + Struct(struct_name) => { + let struct_def = self.struct_defs.get(struct_name).unwrap() + let field_type = struct_def + .get_field_type(field_name) + .or_error( + TypeCheckError( + "Struct '\{struct_name}' has no field '\{field_name}'.", + ), + ) + { kind: FieldAccess(base_typed, field_name), ty: field_type.kind } + } + t => raise TypeCheckError("Attempting to access field of type '\{t}'.") + } + } + FunctionCall(callee_expr, arg_exprs) => { + let callee_typed = self.check_apply_expr(callee_expr) + guard callee_typed.ty is Function(param_types, return_type) else { + raise TypeCheckError("Attempting to call a non-function type.") + } + let args_count = arg_exprs.length() + guard param_types.length() == args_count else { + raise TypeCheckError( + "Function called with incorrect number of arguments.", + ) + } + let args_typed = [] + for i in 0.. 
raise TypeCheckError("Unsupported apply expression \{e}") + } +} diff --git a/src/typecheck/expr_atom.mbt b/src/typecheck/expr_atom.mbt new file mode 100644 index 0000000..3eb4b02 --- /dev/null +++ b/src/typecheck/expr_atom.mbt @@ -0,0 +1,106 @@ +///| +pub(all) struct AtomExpr { + kind : AtomExprKind + ty : TypeKind +} derive(Show) + +///| +pub(all) enum AtomExprKind { + Int(Int) // 1, 42, etc + Double(Double) // 1.0, 3.14, etc + Bool(Bool) // true | false + Unit // () + Ident(String) // var + Tuple(Array[Expr]) // (expr, expr, ...) + Array(Array[Expr]) // [expr, expr, ...] + ArrayMake(Expr, Expr) // Array::make(size, init) + StructConstruct(StructConstructExpr) // StructName::{ field: expr, ... } +} derive(Show) + +///| +pub(all) struct StructConstructExpr { + name : String + fields : Array[(String, Expr)] +} derive(Show) + +///| +pub fn Context::check_atom_expr( + self : Self, + atom_expr : @parser.Expr, +) -> AtomExpr raise TypeCheckError { + match atom_expr { + Literal(Int(v)) => { kind: Int(v), ty: Int } + Literal(Double(v)) => { kind: Double(v), ty: Double } + Literal(Bool(v)) => { kind: Bool(v), ty: Bool } + Literal(Unit) => { kind: Unit, ty: Unit } + Identifier(name) => { + let var_type = self.lookup_type(name) + match var_type { + Some(ty) => { kind: Ident(name), ty: ty.kind } + None => raise TypeCheckError("Undefined variable: \{name}") + } + } + Tuple(elems) => { + let exprs = [] + let types = [] + for expr in elems { + let typed_expr = self.check_expr(expr) + exprs.push(typed_expr) + types.push(typed_expr.ty) + } + { kind: Tuple(exprs), ty: Tuple(types) } + } + Array([]) => { kind: Array([]), ty: Array(self.new_type_var()) } + Array([first_elem, .. 
other_elems]) => { + let first_elem = self.check_expr(first_elem) + let elem_type = first_elem.ty + let exprs = [first_elem] + for expr in other_elems { + let typed_expr = self.check_expr(expr) + if !self.is_type_compatible(elem_type, typed_expr.ty) { + raise TypeCheckError( + "Incompatible types in array expression: \{typed_expr.ty} is not \{elem_type}", + ) + } + exprs.push(typed_expr) + } + { kind: Array(exprs), ty: Array(elem_type) } + } + ArrayMake(size_expr, init_expr) => { + let size_typed = self.check_expr(size_expr) + guard size_typed.ty is Int else { + raise TypeCheckError("Array size must be of type Int.") + } + let init_typed = self.check_expr(init_expr) + { kind: ArrayMake(size_typed, init_typed), ty: Array(init_typed.ty) } + } + StructConstruct(name, fields) => { + let def = self.struct_defs + .get(name) + .or_error(TypeCheckError("Undefined struct: \{name}")) + let field_exprs = [] + for field in fields { + let (field_name, field_expr) = field + let expected_type = def + .get_field_type(field_name) + .or_error( + TypeCheckError( + "Struct '\{name}' has no field named '\{field_name}'.", + ), + ) + let typed_expr = self.check_expr(field_expr) + if !self.is_type_compatible(expected_type.kind, typed_expr.ty) { + raise TypeCheckError( + "Type mismatch for field '\{field_name}' in struct '\{name}'.", + ) + } + field_exprs.push((field_name, typed_expr)) + } + guard field_exprs.length() == def.fields.length() else { + raise TypeCheckError("Struct '\{name}' construction missing fields.") + } + { kind: StructConstruct({ name, fields: field_exprs }), ty: Struct(name) } + } + e => raise TypeCheckError("Unsupported atom expression \{e}") + } +} diff --git a/src/typecheck/expr_block.mbt b/src/typecheck/expr_block.mbt new file mode 100644 index 0000000..658e177 --- /dev/null +++ b/src/typecheck/expr_block.mbt @@ -0,0 +1,23 @@ +///| +pub(all) struct BlockExpr { + stmts : Array[Stmt] + ty : TypeKind +} derive(Show) + +///| +pub fn Context::check_block_expr( + self : 
Context, + block_expr : @parser.Expr, +) -> BlockExpr raise TypeCheckError { + guard block_expr is Block(stmts) else { + raise TypeCheckError("Expected block expression") + } + self.enter_scope() + let checked_stmts = stmts.map(stmt => self.check_stmt(stmt)) + self.exit_scope() + let stmts_count = checked_stmts.length() + guard stmts_count > 0 && checked_stmts[stmts_count - 1].kind is ExprStmt(expr) else { + return { stmts: checked_stmts, ty: Unit } + } + { stmts: checked_stmts, ty: expr.ty } +} diff --git a/src/typecheck/expr_if.mbt b/src/typecheck/expr_if.mbt new file mode 100644 index 0000000..476a175 --- /dev/null +++ b/src/typecheck/expr_if.mbt @@ -0,0 +1,45 @@ +///| +pub(all) struct IfExpr { + cond : Expr + then_block : BlockExpr + else_block : Expr? + ty : TypeKind +} derive(Show) + +///| +pub fn Context::check_if_expr( + self : Context, + if_expr : @parser.Expr, +) -> IfExpr raise TypeCheckError { + guard if_expr is If(cond, then_block, else_block) else { + raise TypeCheckError("Expected if expression") + } + let cond = self.check_expr(cond) + guard self.is_type_compatible(cond.ty, Bool) else { + raise TypeCheckError("if condition must be of type Bool, got \{cond.ty}") + } + guard self.check_expr(then_block) is { kind: BlockExpr(then_block), .. 
} else {
+    raise TypeCheckError("then block must be a block expression")
+  }
+  let else_block = match else_block {
+    Some(else_expr) => {
+      let checked_else = self.check_expr(else_expr)
+      guard self.is_type_compatible(checked_else.ty, then_block.ty) else {
+        raise TypeCheckError(
+          "else block type \{checked_else.ty} does not match then block type \{then_block.ty}",
+        )
+      }
+      Some(checked_else)
+    }
+    None => {
+      // Without an else branch the whole expression can only be Unit.
+      guard self.is_type_compatible(then_block.ty, Unit) else {
+        raise TypeCheckError(
+          "if expression without else block must have then block of type Unit, got \{then_block.ty}",
+        )
+      }
+      None
+    }
+  }
+  let ty = self.deref_type_var(then_block.ty)
+  { cond, then_block, else_block, ty }
+}
diff --git a/src/typecheck/moon.pkg.json b/src/typecheck/moon.pkg.json
new file mode 100644
index 0000000..82ff965
--- /dev/null
+++ b/src/typecheck/moon.pkg.json
@@ -0,0 +1,5 @@
+{
+  "import": [
+    "Lil-Ran/lilunar/parser"
+  ]
+}
\ No newline at end of file
diff --git a/src/typecheck/program.mbt b/src/typecheck/program.mbt
new file mode 100644
index 0000000..4691d2d
--- /dev/null
+++ b/src/typecheck/program.mbt
@@ -0,0 +1,153 @@
+///|
+/// Error raised by the type-checking routines; carries a message string.
+pub(all) suberror TypeCheckError String derive(Show)
+
+///|
+/// One lexical scope: `local_` holds the bindings made in this scope and
+/// `parent` links to the enclosing scope (`None` for the outermost one).
+pub(all) struct Env {
+  local_ : Map[String, Type]
+  parent : Env?
+}
+
+///|
+/// Creates an empty scope, optionally nested inside `parent`.
+pub fn Env::new(parent? : Env? = None) -> Env {
+  { local_: Map::new(), parent }
+}
+
+///|
+/// Looks `name` up in this scope first and then in each enclosing scope.
+/// Returns `None` when no scope binds the name.
+pub fn Env::get(self : Env, name : String) -> Type? {
+  match self.local_.get(name) {
+    Some(t) => Some(t)
+    None =>
+      match self.parent {
+        Some(p) => p.get(name)
+        None => None
+      }
+  }
+}
+
+///|
+/// Binds `name` in this scope only; enclosing scopes are not modified.
+pub fn Env::set(self : Env, name : String, t : Type) -> Unit {
+  self.local_.set(name, t)
+}
+
+///|
+/// Mutable state shared by all checking routines.
+pub(all) struct Context {
+  // Innermost scope currently being checked.
+  mut type_env : Env
+  // Type variable id -> what the checker currently records for that variable.
+  type_vars : Map[Int, TypeKind]
+  // Struct name -> checked struct definition.
+  struct_defs : Map[String, StructDef]
+  // Top-level function name -> function type.
+  func_types : Map[String, TypeKind]
+  // Return type expected by `return` statements; `None` outside a function.
+  mut current_func_ret_ty : TypeKind?
+}
+
+///|
+/// Creates a context with an empty root scope and no registered definitions.
+pub fn Context::new() -> Context {
+  {
+    type_env: Env::new(),
+    type_vars: Map::new(),
+    struct_defs: Map::new(),
+    func_types: Map::new(),
+    current_func_ret_ty: None,
+  }
+}
+
+///|
+/// Allocates a fresh type variable. Its id is the current size of
+/// `type_vars`, and it is initially recorded as mapping to itself.
+pub fn Context::new_type_var(self : Context) -> TypeKind {
+  let type_var_id = self.type_vars.length()
+  self.type_vars.set(type_var_id, TypeVar(type_var_id))
+  TypeVar(type_var_id)
+}
+
+///|
+/// Looks up the type bound to `name` and passes its kind through
+/// `deref_type_var`. Returns `None` if the name is unbound.
+pub fn Context::lookup_type(self : Context, name : String) -> Type? {
+  let looked = self.type_env.get(name)
+  if looked is None {
+    return None
+  }
+  let { kind, mutable } = looked.unwrap()
+  let deref_kind = self.deref_type_var(kind)
+  Some({ kind: deref_kind, mutable })
+}
+
+///|
+/// Pushes a new scope whose parent is the current one.
+pub fn Context::enter_scope(self : Self) -> Unit {
+  let sub_env = Env::new(parent=Some(self.type_env))
+  self.type_env = sub_env
+}
+
+///|
+/// Pops the current scope. At the root scope (no parent) this is a no-op.
+pub fn Context::exit_scope(self : Context) -> Unit {
+  self.type_env = match self.type_env.parent {
+    Some(p) => p
+    None => self.type_env
+  }
+}
+
+///|
+/// Sets the return type that subsequent `return` statements are checked against.
+pub fn Context::set_current_func_ret_ty(self : Context, ty : TypeKind) -> Unit {
+  self.current_func_ret_ty = Some(ty)
+}
+
+///|
+/// Fully checked program, keyed by definition name.
+pub(all) struct Program {
+  top_lets : Map[String, TopLet]
+  top_functions : Map[String, TopFunction]
+  struct_defs : Map[String, StructDef]
+} derive(Show)
+
+///|
+/// Type-checks a whole parsed program.
+///
+/// Struct names and placeholder types for top-level functions and lets are
+/// registered first. Then struct definitions, top-level functions and
+/// top-level lets are checked, in that order.
+pub fn Context::check_program(
+  self : Context,
+  program : @parser.Program,
+) -> Program raise TypeCheckError {
+  self.collect_struct_names(program)
+  self.collect_function_types(program)
+  let struct_defs : Map[String, StructDef] = Map::new()
+  for name, struct_def in program.struct_defs {
+    let checked_struct_def = self.check_struct_def(struct_def)
+    struct_defs.set(name, checked_struct_def)
+    // Replace the placeholder registered by `collect_struct_names`.
+    self.struct_defs.set(name, checked_struct_def)
+  }
+  let top_functions : Map[String, TopFunction] = Map::new()
+  for name, func in program.top_functions {
+    let checked_func = self.check_top_function(func)
+    top_functions.set(name, checked_func)
+    // Overwrite the placeholder type variable with the checked signature.
+    self.func_types.set(
+      name,
+      Function(checked_func.param_list.map(p => p.ty), checked_func.ret_ty),
+    )
+  }
+  let top_lets : Map[String, TopLet] = Map::new()
+  for name, let_def in program.top_lets {
+    let checked_let = self.check_top_let(let_def)
+    top_lets.set(name, checked_let)
+    self.type_env.set(name, checked_let.ty)
+  }
+  { top_lets, top_functions, struct_defs }
+}
+
+///|
+/// Registers every struct name with a placeholder definition (no fields).
+/// Raises `TypeCheckError` if a name is already registered in the context.
+pub fn Context::collect_struct_names(
+  self : Context,
+  program : @parser.Program,
+) -> Unit raise TypeCheckError {
+  for name, _ in program.struct_defs {
+    if self.struct_defs.contains(name) {
+      raise TypeCheckError("Duplicate struct definition: \{name}")
+    }
+    self.struct_defs.set(name, { name, fields: [] })
+  }
+}
+
+///|
+/// Gives every top-level function and every top-level let a fresh type
+/// variable as a placeholder type. Lets are bound as immutable in the
+/// current scope.
+pub fn Context::collect_function_types(
+  self : Context,
+  program : @parser.Program,
+) -> Unit {
+  for name, _ in program.top_functions {
+    let func_type = self.new_type_var()
+    self.func_types.set(name, func_type)
+  }
+  for name, _ in program.top_lets {
+    let let_type = self.new_type_var()
+    self.type_env.set(name, { kind: let_type, mutable: false })
+  }
+}
diff --git a/src/typecheck/stmt.mbt b/src/typecheck/stmt.mbt
new file mode 100644
index 0000000..6fd57df
--- /dev/null
+++ b/src/typecheck/stmt.mbt
@@ -0,0 +1,43 @@
+///|
+/// Typed statement wrapper.
+pub(all) struct Stmt {
+  kind : StmtKind
+} derive(Show)
+
+///|
+/// The kinds of typed statement produced by `check_stmt`.
+pub(all) enum StmtKind {
+  LetStmt(LetStmt)
+  LetMutStmt(LetMutStmt)
+  AssignStmt(AssignStmt)
+  WhileStmt(WhileStmt)
+  ExprStmt(Expr)
+  ReturnStmt(Expr)
+  LocalFunction(LocalFunction)
+} derive(Show)
+
+///|
+/// Type-checks one statement, dispatching on its parser-level variant.
+pub fn Context::check_stmt(
+  self : Context,
+  stmt : @parser.Stmt,
+) -> Stmt raise TypeCheckError {
+  match stmt {
+    Return(expr) => {
+      // `return` is only legal while a function's return type is recorded.
+      guard self.current_func_ret_ty is Some(expected) else {
+        raise TypeCheckError("return statement outside of function")
+      }
+      let actual = self.check_expr(expr)
+      if !self.is_type_compatible(actual.ty, expected) {
+        raise TypeCheckError(
+          "return type mismatch: expected \{expected}, got \{actual.ty}",
+        )
+      }
+      { kind: ReturnStmt(actual) }
+    }
+    Expr(expr) => { kind: ExprStmt(self.check_expr(expr)) }
+
While(_) => { kind: WhileStmt(self.check_while_stmt(stmt)) }
+    Assign(_) => { kind: AssignStmt(self.check_assign_stmt(stmt)) }
+    LetMut(_) => { kind: LetMutStmt(self.check_let_mut_stmt(stmt)) }
+    Let(_) => { kind: LetStmt(self.check_let_stmt(stmt)) }
+    LocalFunction(func) =>
+      { kind: LocalFunction(self.check_local_function(func)) }
+  }
+}
diff --git a/src/typecheck/stmt_assign.mbt b/src/typecheck/stmt_assign.mbt
new file mode 100644
index 0000000..768eb54
--- /dev/null
+++ b/src/typecheck/stmt_assign.mbt
@@ -0,0 +1,84 @@
+///|
+/// Typed assignment target. `ty.mutable` records whether it may be assigned.
+pub(all) struct LeftValue {
+  kind : LeftValueKind
+  ty : Type
+} derive(Show)
+
+///|
+/// Shapes an assignment target can take.
+pub(all) enum LeftValueKind {
+  Ident(String)
+  ArrayAccess(LeftValue, Expr)
+  FieldAccess(LeftValue, String)
+} derive(Show)
+
+///|
+/// Typed assignment statement.
+pub(all) struct AssignStmt {
+  left_value : LeftValue
+  expr : Expr
+} derive(Show)
+
+///|
+/// Type-checks the left-hand side of an assignment.
+///
+/// Accepts identifiers, index accesses and field accesses; any other
+/// expression raises `TypeCheckError`.
+pub fn Context::check_left_value(
+  self : Context,
+  lv : @parser.Expr,
+) -> LeftValue raise TypeCheckError {
+  match lv {
+    Identifier(name) => {
+      guard self.lookup_type(name) is Some(var_info) else {
+        raise TypeCheckError("Undefined variable '\{name}'.")
+      }
+      { kind: Ident(name), ty: var_info }
+    }
+    IndexAccess(base, index_expr) => {
+      let base_typed = self.check_left_value(base)
+      let index_typed = self.check_expr(index_expr)
+      guard index_typed.ty is Int else {
+        raise TypeCheckError("Array index must be of type Int.")
+      }
+      guard base_typed.ty.kind is Array(elem_type) else {
+        raise TypeCheckError("Base of array access is not an array.")
+      }
+      {
+        kind: ArrayAccess(base_typed, index_typed),
+        ty: { kind: elem_type, mutable: true }, // MiniMoonBit Array elements are always mutable
+      }
+    }
+    FieldAccess(base, field_name) => {
+      let base_typed = self.check_left_value(base)
+      guard base_typed.ty.kind is Struct(name) else {
+        raise TypeCheckError("Base of field access is not a struct.")
+      }
+      let field_type = self.struct_defs
+        .get(name)
+        .or_error(TypeCheckError("Undefined struct: \{name}"))
+        .get_field_type(field_name)
+        .or_error(TypeCheckError("Struct has no field named '\{field_name}'."))
+      {
+        kind: FieldAccess(base_typed, field_name),
+        // A field is assignable exactly when its base is assignable.
+        ty: { kind: field_type.kind, mutable: base_typed.ty.mutable },
+      }
+    }
+    lv => raise TypeCheckError("Invalid left value expression: \{lv}")
+  }
+}
+
+///|
+/// Type-checks an assignment: the target must be mutable and the value's type
+/// must be compatible with the target's type.
+pub fn Context::check_assign_stmt(
+  self : Context,
+  assign_stmt : @parser.Stmt,
+) -> AssignStmt raise TypeCheckError {
+  guard assign_stmt is Assign(left_value, expr) else {
+    raise TypeCheckError("Expected AssignStmt.")
+  }
+  let left_value_typed = self.check_left_value(left_value)
+  guard left_value_typed.ty.mutable else {
+    raise TypeCheckError("Cannot assign to immutable left value \{left_value}.")
+  }
+  let expr_typed = self.check_expr(expr)
+  guard self.is_type_compatible(expr_typed.ty, left_value_typed.ty.kind) else {
+    raise TypeCheckError(
+      "Cannot assign value of type \{expr_typed.ty} to variable of type \{left_value_typed.ty}.",
+    )
+  }
+  { left_value: left_value_typed, expr: expr_typed }
+}
diff --git a/src/typecheck/stmt_let.mbt b/src/typecheck/stmt_let.mbt
new file mode 100644
index 0000000..dc4d0b3
--- /dev/null
+++ b/src/typecheck/stmt_let.mbt
@@ -0,0 +1,77 @@
+///|
+/// Typed immutable `let` statement.
+pub(all) struct LetStmt {
+  pattern : Pattern
+  ty : TypeKind
+  expr : Expr
+} derive(Show)
+
+///|
+/// Binding pattern on the left of a `let`.
+pub(all) struct Pattern {
+  kind : PatternKind
+} derive(Show, Eq)
+
+///|
+/// Pattern shapes: ignore the value, bind a name, or destructure a tuple.
+pub(all) enum PatternKind {
+  Wildcard
+  Ident(String)
+  Tuple(Array[Pattern])
+} derive(Show, Eq)
+
+///|
+/// Type-checks a `let` statement and binds the names in its pattern as
+/// immutable. When a type annotation is present it must be compatible with
+/// the initializer's type, and the annotation becomes the statement's type.
+pub fn Context::check_let_stmt(
+  self : Context,
+  let_stmt : @parser.Stmt,
+) -> LetStmt raise TypeCheckError {
+  guard let_stmt is Let(pattern, ty, expr) else {
+    raise TypeCheckError("Expected LetStmt.")
+  }
+  let expr_typed = self.check_expr(expr)
+  let mut result_ty = expr_typed.ty
+  if ty is Some(ty) {
+    result_ty = self.check_parser_type(ty).kind
+    guard self.is_type_compatible(result_ty, expr_typed.ty) else {
+      raise TypeCheckError(
+        "Type annotation does not match expression type for pattern '\{pattern}'.",
+      )
+    }
+  }
+  let pattern = parser_pattern_map(pattern)
+  result_ty = self.deref_type_var(result_ty)
+  self.set_pattern_types(pattern, result_ty)
+  { pattern, ty: result_ty, expr: expr_typed }
+}
+
+///|
+/// Binds every name in `pattern` to the matching part of `ty` as immutable.
+/// A tuple pattern requires a tuple type of the same length.
+pub fn Context::set_pattern_types(
+  self : Context,
+  pattern : Pattern,
+  ty : TypeKind,
+) -> Unit raise TypeCheckError {
+  match pattern.kind {
+    Ident(name) => self.type_env.set(name, { kind: ty, mutable: false })
+    Wildcard => ()
+    Tuple(patterns) => {
+      guard ty is Tuple(elem_types) else {
+        raise TypeCheckError("Type mismatch for tuple pattern.")
+      }
+      guard patterns.length() == elem_types.length() else {
+        raise TypeCheckError("Tuple pattern length does not match type.")
+      }
+      // NOTE(review): the text between `0..` and `Pattern {` on the next line
+      // appears to have been lost when this patch was extracted (presumably
+      // the rest of this loop and the header of `parser_pattern_map`) --
+      // restore it from the original commit.
+      for i in 0.. Pattern {
+  match pattern {
+    Identifier(name) => { kind: Ident(name) }
+    Wildcard => { kind: Wildcard }
+    Tuple(bindings) => {
+      let patterns = bindings.map(binding => parser_pattern_map(binding))
+      { kind: Tuple(patterns) }
+    }
+  }
+}
diff --git a/src/typecheck/stmt_let_mut.mbt b/src/typecheck/stmt_let_mut.mbt
new file mode 100644
index 0000000..c3debea
--- /dev/null
+++ b/src/typecheck/stmt_let_mut.mbt
@@ -0,0 +1,29 @@
+///|
+/// Typed `let mut` statement.
+pub(all) struct LetMutStmt {
+  name : String
+  ty : Type
+  expr : Expr
+} derive(Show)
+
+///|
+/// Type-checks a `let mut` statement and binds `name` as mutable in the
+/// current scope. When a type annotation is present it must be compatible
+/// with the initializer's type.
+pub fn Context::check_let_mut_stmt(
+  self : Context,
+  stmt : @parser.Stmt,
+) -> LetMutStmt raise TypeCheckError {
+  guard stmt is LetMut(name, ty, expr) else {
+    raise TypeCheckError("Expected LetMutStmt.")
+  }
+  let expr_typed = self.check_expr(expr)
+  let mut result_ty = expr_typed.ty
+  if ty is Some(ty) {
+    result_ty = self.check_parser_type(ty).kind
+    guard self.is_type_compatible(result_ty, expr_typed.ty) else {
+      raise TypeCheckError(
+        "Type annotation does not match expression type for variable '\{name}'.",
+      )
+    }
+  }
+  result_ty = self.deref_type_var(result_ty)
+  self.type_env.set(name, { kind: result_ty, mutable: true })
+  { name, ty: { kind: result_ty, mutable: true }, expr: expr_typed }
+}
diff --git a/src/typecheck/stmt_local_function.mbt
b/src/typecheck/stmt_local_function.mbt
new file mode 100644
index 0000000..6906601
--- /dev/null
+++ b/src/typecheck/stmt_local_function.mbt
@@ -0,0 +1,67 @@
+///|
+/// Typed form of a function declared inside another function body.
+pub(all) struct LocalFunction {
+  fname : String
+  param_list : Array[(String, Type)]
+  ret_ty : Type
+  body : BlockExpr
+} derive(Show)
+
+///|
+/// Type-checks a function declared inside another function body.
+///
+/// The function name is bound in the current scope before the body is checked,
+/// so the body can refer to the function itself. Parameters and a return type
+/// without an annotation each get a fresh type variable.
+///
+/// The enclosing function's expected return type is saved before the body is
+/// checked and restored afterwards, so `return` statements that follow this
+/// declaration in the outer function are checked against the outer type.
+///
+/// Raises `TypeCheckError` if the body fails to check or its type is
+/// incompatible with the return type.
+pub fn Context::check_local_function(
+  self : Context,
+  func : @parser.Function,
+) -> LocalFunction raise TypeCheckError {
+  let param_list = func.params.map(param => {
+    let (param_name, param_type) = param
+    match param_type {
+      Some(ty) =>
+        (param_name, { kind: self.check_parser_type(ty).kind, mutable: false })
+      None => (param_name, { kind: self.new_type_var(), mutable: false })
+    }
+  })
+  let ret_type = match func.return_type {
+    Some(ty) => self.check_parser_type(ty).kind
+    None => self.new_type_var()
+  }
+  let func_type = Function(param_list.map(p => p.1.kind), ret_type)
+  self.type_env.set(func.id, { kind: func_type, mutable: false })
+  // Remember the enclosing function's return type so it can be restored below.
+  let outer_ret_ty = self.current_func_ret_ty
+  self.enter_scope()
+  self.current_func_ret_ty = Some(ret_type)
+  // Parameters are bound in the function's own scope.
+  for param in param_list {
+    let (param_name, param_type) = param
+    self.type_env.set(param_name, param_type)
+  }
+  //
+  let checked_body = self.check_block_expr(Block(func.body))
+  if !self.is_type_compatible(checked_body.ty, ret_type) {
+    raise TypeCheckError(
+      "Function '\{func.id}' return type mismatch: expected \{ret_type}, got \{checked_body.ty}",
+    )
+  }
+  // Leave the function: pop its scope and restore the outer return type.
+  self.exit_scope()
+  self.current_func_ret_ty = outer_ret_ty
+  //
+  {
+    fname: func.id,
+    param_list,
+    ret_ty: { kind: ret_type, mutable: false },
+    body: checked_body,
+  }
+}
diff --git a/src/typecheck/stmt_while.mbt b/src/typecheck/stmt_while.mbt
new file mode 100644
index 0000000..e154390
--- /dev/null
+++ b/src/typecheck/stmt_while.mbt
@@ -0,0 +1,26 @@
+///|
+/// Typed `while` loop: a checked condition and a checked body block.
+pub(all) struct WhileStmt {
+  cond : Expr
+  body : BlockExpr
+} derive(Show)
+
+///|
+/// Type-checks a `while` statement: the condition must be compatible with
+/// `Bool`, and the body statements are checked as a block expression.
+pub fn Context::check_while_stmt(
+  self : Context,
+  while_stmt : @parser.Stmt,
+) -> WhileStmt raise TypeCheckError {
+  guard while_stmt is While(cond_expr, body_expr) else {
+    raise TypeCheckError("Expected while statement")
+  }
+  let cond = self.check_expr(cond_expr)
+  guard self.is_type_compatible(cond.ty, Bool) else {
+    raise TypeCheckError("while condition must be of type Bool, got \{cond.ty}")
+  }
+  guard self.check_expr(Block(body_expr)) is { kind: BlockExpr(body_block), .. } else {
+    raise TypeCheckError("while body must be a block expression")
+  }
+  { cond, body: body_block }
+}
diff --git a/src/typecheck/struct_def.mbt b/src/typecheck/struct_def.mbt
new file mode 100644
index 0000000..4d27e2c
--- /dev/null
+++ b/src/typecheck/struct_def.mbt
@@ -0,0 +1,42 @@
+///|
+/// Checked struct definition: its name and its fields.
+pub(all) struct StructDef {
+  name : String
+  fields : Array[StructField]
+} derive(Show)
+
+///|
+/// A single struct field with its resolved type.
+pub(all) struct StructField {
+  name : String
+  ty : Type
+} derive(Show)
+
+///|
+/// Returns the type of the first field named `field_name`, or `None` when the
+/// struct has no such field.
+pub fn StructDef::get_field_type(self : Self, field_name : String) -> Type? {
+  for field in self.fields {
+    if field.name == field_name {
+      return Some(field.ty)
+    }
+  }
+  return None
+}
+
+///|
+/// Resolves every field's parser-level type annotation and returns the checked
+/// definition. Raises `TypeCheckError` if a field type cannot be resolved.
+pub fn Context::check_struct_def(
+  self : Self,
+  struct_def : @parser.StructDef,
+) -> StructDef raise TypeCheckError {
+  {
+    name: struct_def.id,
+    fields: struct_def.fields.map(field => {
+      let (field_name, field_type) = field
+      let field_type = self.check_parser_type(field_type)
+      { name: field_name, ty: field_type }
+    }),
+  }
+}
diff --git a/src/typecheck/top_function.mbt b/src/typecheck/top_function.mbt
new file mode 100644
index 0000000..db8c3e6
--- /dev/null
+++ b/src/typecheck/top_function.mbt
@@ -0,0 +1,77 @@
+///|
+/// A top-level function parameter with its resolved type.
+pub(all) struct Param {
+  name : String
+  ty : TypeKind
+} derive(Show)
+
+///|
+/// Typed form of a top-level function.
+pub(all) struct TopFunction {
+  fname : String
+  param_list : Array[Param]
+  ret_ty : TypeKind
+  body : BlockExpr
+} derive(Show)
+
+///|
+/// Type-checks a top-level function definition.
+///
+/// Unannotated parameters and return types get fresh type variables. The
+/// function's own type is registered in the current scope before the body is
+/// checked. The previous expected return type is restored on exit, so it does
+/// not leak into whatever is checked next.
+///
+/// Raises `TypeCheckError` if the body fails to check or its type is
+/// incompatible with the return type.
+pub fn Context::check_top_function(
+  self : Context,
+  func : @parser.Function,
+) -> TopFunction raise TypeCheckError {
+  // XXX: generics are not true generics yet; each one is just a type
+  // variable that can only unify to a single type.
+  if func.user_defined_type is Some(UserDefined(udt)) {
+    self.type_env.set("$Generic$\{udt}", {
+      kind: self.new_type_var(),
+      mutable: false,
+    })
+  }
+  //
+  let param_list : Array[Param] = func.params.map(param => {
+    let (param_name, param_type) = param
+    match param_type {
+      Some(ty) => { name: param_name, ty: self.check_parser_type(ty).kind }
+      None => { name: param_name, ty: self.new_type_var() }
+    }
+  })
+  let ret_ty = match func.return_type {
+    Some(ty) => self.check_parser_type(ty).kind
+    None => self.new_type_var()
+  }
+  let func_type = Function(param_list.map(p => p.ty), ret_ty)
+  self.type_env.set(func.id, { kind: func_type, mutable: false })
+  // Remember the previous expected return type so it can be restored below.
+  let outer_ret_ty = self.current_func_ret_ty
+  self.enter_scope()
+  self.current_func_ret_ty = Some(ret_ty)
+  //
+  for param in param_list {
+    let { name, ty } = param
+    self.type_env.set(name, { kind: ty, mutable: false })
+  }
+  //
+  let checked_body = self.check_block_expr(Block(func.body))
+  if !self.is_type_compatible(checked_body.ty, ret_ty) {
+    raise TypeCheckError(
+      "Function '\{func.id}' return type mismatch: expected \{ret_ty}, got \{checked_body.ty}",
+    )
+  }
+  // Leave the function: pop its scope and restore the previous return type.
+  self.exit_scope()
+  self.current_func_ret_ty = outer_ret_ty
+  if func.user_defined_type is Some(UserDefined(udt)) {
+    self.type_env.local_.remove("$Generic$\{udt}")
+  }
+  //
+  { fname: func.id, param_list, ret_ty, body: checked_body }
+}
diff --git a/src/typecheck/top_let.mbt b/src/typecheck/top_let.mbt
new file mode 100644
index 0000000..9a45540
--- /dev/null
+++ b/src/typecheck/top_let.mbt
@@ -0,0 +1,27 @@
+///|
+pub(all) struct TopLet {
+  name : String
+  ty : Type
+  expr : Expr
+} derive(Show)
+
+///|
+pub fn Context::check_top_let(
+  self : Context,
+  top_let : @parser.TopLet,
+) -> TopLet raise TypeCheckError {
+  let { id: name, type_: ty, expr } = top_let
+  let expr_typed = self.check_expr(expr)
+  let mut result_ty = expr_typed.ty
+  if ty is Some(ty) {
+    result_ty = self.check_parser_type(ty).kind
+    guard self.is_type_compatible(result_ty, expr_typed.ty) else {
+      raise TypeCheckError(
+        "Type annotation does not match expression type for variable '\{name}'.",
+      )
+    }
+  }
+
result_ty = self.deref_type_var(result_ty) + self.type_env.set(name, { kind: result_ty, mutable: false }) + { name, ty: { kind: result_ty, mutable: false }, expr: expr_typed } +} diff --git a/src/typecheck/typecheck_test.mbt b/src/typecheck/typecheck_test.mbt new file mode 100644 index 0000000..5e3d0d6 --- /dev/null +++ b/src/typecheck/typecheck_test.mbt @@ -0,0 +1,826 @@ +///| +test "TypeCheck Normal Type Test" { + let ctx = Context::new() + let code = + #|Unit Int Bool Double + #|(Int, Bool) Array[Int] + #|Array[(Int, Double)] + #|(Int, Bool) -> Double + let tokens = @parser.tokenize(code) + let (t, tok_view) = @parser.parse_type(tokens) + let t = ctx.check_parser_type(t) + assert_true(t.kind is Unit) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true(t.kind is Int) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true(t.kind is Bool) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true(t.kind is Double) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true(t.kind is Tuple([t1, t2]) && t1 is Int && t2 is Bool) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true(t.kind is Array(t1) && t1 is Int) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true( + t.kind is Array(t1) && t1 is Tuple([t2, t3]) && t2 is Int && t3 is Double, + ) + let (t, _) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true( + t.kind is Function([t1, t2], t3) && t1 is Int && t2 is Bool && t3 is Double, + ) +} + +///| +test "TypeCheck Defined Type Test" { + let ctx = Context::new() + ctx.struct_defs.set("Point", { name: "Point", fields: [] }) + ctx.struct_defs.set("Circle", { name: "Circle", fields: [] }) + ctx.struct_defs.set("Rectangle", { name: "Rectangle", fields: [] }) + // Note: No 
Triangle + let code = + #|Point Circle Rectangle Triangle + #|Array[Point] (Point, Circle) + #|(Rectangle, Point) -> Circle + #|Array[Triangle] + let tokens = @parser.tokenize(code) + let (t, tok_view) = @parser.parse_type(tokens) + let t = ctx.check_parser_type(t) + assert_true(t.kind is Struct("Point")) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true(t.kind is Struct("Circle")) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true(t.kind is Struct("Rectangle")) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = try? ctx.check_parser_type(t) + assert_true(t is Err(_)) // Triangle is not defined + let (t, tok_view) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true(t.kind is Array(t1) && t1 is Struct("Point")) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true( + t.kind is Tuple([t1, t2]) && t1 is Struct("Point") && t2 is Struct("Circle"), + ) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = ctx.check_parser_type(t) + assert_true( + t.kind is Function([t1, t2], t3) && + t1 is Struct("Rectangle") && + t2 is Struct("Point") && + t3 is Struct("Circle"), + ) + let (t, _) = @parser.parse_type(tok_view) + let t = try? 
ctx.check_parser_type(t) + assert_true(t is Err(_)) // Triangle is not defined +} + +///| +test "Struct Definition Typecheck" { + let code = + #|struct Point { x: Int; y: Int; } + #|struct Queue { data: Array[Int]; front: Int; back: Int; } + let tokens = @parser.tokenize(code) + let ctx = Context::new() + // Parse + let (struct_def, tok_view) = @parser.parse_struct_decl(tokens[:]) + // Typecheck for `struct Point { x: Int; y: Int }` + let struct_def = ctx.check_struct_def(struct_def) + assert_true(struct_def.name is "Point") + assert_true(struct_def.fields.length() is 2) + assert_true( + struct_def.fields is [f1, f2] && + f1 is { name: "x", ty: { kind: Int, mutable: false } } && + f2 is { name: "y", ty: { kind: Int, mutable: false } }, + ) + // Typecheck for `struct Queue { data: Array[Int]; front: Int; back: Int }` + let (struct_def, _) = @parser.parse_struct_decl(tok_view) + let struct_def = ctx.check_struct_def(struct_def) + assert_true(struct_def.name is "Queue") + assert_true(struct_def.fields.length() is 3) + assert_true( + struct_def.fields is [f1, f2, f3] && + f1 is { name: "data", ty: { kind: Array(Int), mutable: false } } && + f2 is { name: "front", ty: { kind: Int, mutable: false } } && + f3 is { name: "back", ty: { kind: Int, mutable: false } }, + ) +} + +///| +test "Type Var Test - 1" { + let ctx = Context::new() + ctx.type_vars.set(0, TypeVar(0)) + ctx.type_vars.set(1, TypeVar(1)) + ctx.type_vars.set(2, TypeVar(2)) + ctx.type_vars.set(3, TypeVar(3)) + // TVar(0) = Double + let t1 = ctx.is_type_compatible(TypeVar(0), Double) + assert_true(t1 is true) + assert_true(ctx.type_vars.get(0).unwrap() is Double) + // TVar(1) = TVar(0) + let t2 = ctx.is_type_compatible(TypeVar(1), TypeVar(0)) + assert_true(t2 is true) + assert_true(ctx.type_vars.get(1).unwrap() is Double) + // TVar(2) = TVar(3) + let t3 = ctx.is_type_compatible(TypeVar(2), TypeVar(3)) + assert_true(t3 is true) + // TVAr(3) = Int, therefore TVar(2) = Int + let t4 = 
ctx.is_type_compatible(TypeVar(3), Int) + assert_true(t4 is true) + assert_true(ctx.type_vars.get(2).unwrap() is Int) + assert_true(ctx.type_vars.get(3).unwrap() is Int) + let t5 = ctx.is_type_compatible(TypeVar(2), Double) + assert_true(t5 is false) +} + +///| +/// +/// 0 == 1 == 2 == 3 +/// 2 == (Double, Int) +/// +/// 0, 1, 2, 3 should all resolve to (Double, Int) +test "Type Var Test - 2" { + let ctx = Context::new() + ctx.type_vars.set(0, TypeVar(0)) + ctx.type_vars.set(1, TypeVar(1)) + ctx.type_vars.set(2, TypeVar(2)) + ctx.type_vars.set(3, TypeVar(3)) + // + let _ = ctx.is_type_compatible(TypeVar(0), TypeVar(1)) + let _ = ctx.is_type_compatible(TypeVar(1), TypeVar(2)) + let _ = ctx.is_type_compatible(TypeVar(2), TypeVar(3)) + let _ = ctx.is_type_compatible(TypeVar(3), TypeVar(0)) + let _ = ctx.is_type_compatible(TypeVar(2), Tuple([Double, Int])) + let t0 = ctx.type_vars.get(0).unwrap() + let t1 = ctx.type_vars.get(1).unwrap() + let t2 = ctx.type_vars.get(2).unwrap() + let t3 = ctx.type_vars.get(3).unwrap() + assert_true(t0 is Tuple([Double, Int])) + assert_true(t1 is Tuple([Double, Int])) + assert_true(t2 is Tuple([Double, Int])) + assert_true(t3 is Tuple([Double, Int])) +} + +///| +test "Simple Atom Expression Type Check" { + let code = "42 3.14 true x y" + let ctx = Context::new() + ctx.type_env.set("x", { kind: Double, mutable: false }) + ctx.type_env.set("y", { kind: Bool, mutable: false }) + // Parse + let tokens = @parser.tokenize(code) + // Type of 42 + let (a, tok_view) = @parser.parse_expr(tokens[:]) + let a = ctx.check_atom_expr(a) + assert_true(a.ty is Int) + // Type of 3.14 + let (a, tok_view) = @parser.parse_expr(tok_view) + let a = ctx.check_atom_expr(a) + assert_true(a.ty is Double) + // Type of true + let (a, tok_view) = @parser.parse_expr(tok_view) + let a = ctx.check_atom_expr(a) + assert_true(a.ty is Bool) + // Type of x + let (a, tok_view) = @parser.parse_expr(tok_view) + let a = ctx.check_atom_expr(a) + assert_true(a.ty is Double) + // 
Type of y + let (a, _) = @parser.parse_expr(tok_view) + let a = ctx.check_atom_expr(a) + assert_true(a.ty is Bool) +} + +///| +test "Simple Apply Expression Type Check" { + let code = + #|42 3.14 true x y + let ctx = Context::new() + ctx.type_env.set("x", { kind: Double, mutable: false }) + ctx.type_env.set("y", { kind: Bool, mutable: false }) + // Parse + let tokens = @parser.tokenize(code) + // Type of 42 + let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tokens[:]) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Int) + // Type of 3.14 + let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Double) + // Type of true + let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Bool) + // Type of x + let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Double) + // Type of y + let (a, _) = @parser.parse_get_or_apply_level_expr(tok_view) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Bool) +} + +///| +test "Simple Expression Type Check" { + let code = + #|42 ; 3.14 ; true ; + #|!true ; !false ; !x ; + #|-42 ; -3.14 ; -y ; + #|1 + 3; 4 - 5 ; 6 * 7; 8.0 / 2.0; 9 % 4; + #|1 > 2; 42.0 >= y; x == true ; + let ctx = Context::new() + ctx.type_env.set("x", { kind: Bool, mutable: false }) + ctx.type_env.set("y", { kind: Double, mutable: false }) + // Parse + let tokens = @parser.tokenize(code) + // Type of 42 + let (a, tok_view) = @parser.parse_expr(tokens[:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Int) + // Type of 3.14 + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Double) + // Type of true + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Bool) + // Type of !true + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let 
a = ctx.check_expr(a) + assert_true(a.ty is Bool) + // Type of !false + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Bool) + // Type of !x + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Bool) + // Type of -42 + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Int) + // Type of -3.14 + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Double) + // Type of -y + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Double) + // Type of 1 + 3 + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Int) + // Type of 4 - 5 + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Int) + // Type of 6 * 7 + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Int) + // Type of 8.0 / 2.0 + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Double) + // Type of 9 % 4 + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Int) + //// Type of 1 > 2 + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Bool) + //// Type of 42 == y + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Bool) + //// Type of x == true + let (a, _) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Bool) +} + +///| +test "Atom Expression Type Check" { + let code = + #|[1, 2+3, 3+4+5, 6*7+8] (1+3, !y) Array::make(5, 0) [] + let ctx = Context::new() + ctx.type_env.set("x", { kind: Double, mutable: false }) + ctx.type_env.set("y", { kind: Bool, mutable: false 
}) + // Parse + let tokens = @parser.tokenize(code) + // Type of [1, 2+3, 3+4+5, 6*7+8] + let (a, tok_view) = @parser.parse_value_level_expr(tokens[:]) + let a = ctx.check_atom_expr(a) + assert_true(a.ty is Array(Int)) + // Type of (1+3, !y) + let (b, tok_view) = @parser.parse_value_level_expr(tok_view) + let b = ctx.check_atom_expr(b) + assert_true(b.ty is Tuple([Int, Bool])) + // Type of Array::make(5, 0) + let (c, tok_view) = @parser.parse_value_level_expr(tok_view) + let c = ctx.check_atom_expr(c) + assert_true(c.ty is Array(Int)) + // Type of [] + let (d, _) = @parser.parse_value_level_expr(tok_view) + let d = ctx.check_atom_expr(d) + assert_true(d.ty is Array(TypeVar(_))) +} + +///| +test "Apply Expression Type Check" { + let code = + #|arr[3] ; fact(5) ; + #|mat[3][4] ; sum(arr); max(x, y); + #|arr.push(10); mat.length(); + let ctx = Context::new() + ctx.type_env.set("arr", { kind: Array(Int), mutable: false }) + ctx.type_env.set("mat", { kind: Array(Array(Int)), mutable: false }) + ctx.type_env.set("fact", { kind: Function([Int], Int), mutable: false }) + ctx.type_env.set("sum", { kind: Function([Array(Int)], Int), mutable: false }) + // Note: max is local function without type annotations + // like: fn max(a, b) { ... 
} + ctx.type_env.set("max", { + kind: Function([TypeVar(0), TypeVar(0)], TypeVar(0)), + mutable: false, + }) + ctx.type_env.set("x", { kind: Double, mutable: true }) + ctx.type_env.set("y", { kind: Double, mutable: true }) + ctx.type_vars.set(0, TypeVar(0)) + // Parse + let tokens = @parser.tokenize(code) + // Type of arr[3] + let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tokens[:]) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Int) + // Type of fact(5) + let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view[1:]) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Int) + // Type of mat[3][4] + let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view[1:]) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Int) + // Type of sum(arr) + let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view[1:]) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Int) + // Type of max(x, y) + let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view[1:]) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Double) + // Type of arr.push(10) + let (a, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view[1:]) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Unit) + // Type of mat.length() + let (a, _) = @parser.parse_get_or_apply_level_expr(tok_view[1:]) + let a = ctx.check_apply_expr(a) + assert_true(a.ty is Int) +} + +///| +test "Let Mut Stmt Type Check" { + let code = + #|let mut x : Double = 42.0; + #|let mut y = 10; + #|let mut z = []; + #|let _ = z.push(3.14); + #|let _ = z.push(false); // Should fail + let ctx = Context::new() + // Parse + let tokens = @parser.tokenize(code) + // Type check for `let mut x : Double = 42.0;` + let (let_mut_stmt1, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:]) + let checked_let_mut_stmt1 = ctx.check_let_mut_stmt(let_mut_stmt1) + assert_true(checked_let_mut_stmt1.ty.kind is Double) + // Type check for `let mut y = 10;` + let (let_mut_stmt2, _, 
tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let checked_let_mut_stmt2 = ctx.check_let_mut_stmt(let_mut_stmt2) + assert_true(checked_let_mut_stmt2.ty.kind is Int) + // Type check for `let mut z = [];` + let (let_mut_stmt3, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let checked_let_mut_stmt3 = ctx.check_let_mut_stmt(let_mut_stmt3) + assert_true(checked_let_mut_stmt3.ty.kind is Array(TypeVar(_))) + // Type check for `let _ = z.push(3.14);` + // Parse Let Stmt + let (let_stmt4, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let checked_let_stmt4 = ctx.check_let_stmt(let_stmt4) + assert_true(checked_let_stmt4.ty is Unit) + assert_true(ctx.lookup_type("z") is Some({ kind: Array(Double), .. })) + // Type check for `let _ = z.push("Hello");` - Should fail + let (let_stmt5, _, _) = @parser.parse_stmt_or_expr_end(tok_view) + let check_result = try? ctx.check_let_stmt(let_stmt5) + assert_true(check_result is Err(_)) +} + +///| +test "Top Let Stmt Type Check" { + let code = + #|let x : Int = 42; + #|let y = true; + let ctx = Context::new() + // Parse + let tokens = @parser.tokenize(code) + let { top_lets, .. 
} = @parser.parse_program(tokens) + // Type check for `let x : Int = 42;` + let checked_top_let1 = ctx.check_top_let(top_lets["x"]) + assert_true(checked_top_let1.ty.kind is Int) + let ty1 = ctx.lookup_type("x") + assert_true(ty1 is Some(t) && t.kind is Int && t.mutable == false) + + // Type check for `let y = true;` + let checked_top_let2 = ctx.check_top_let(top_lets["y"]) + assert_true(checked_top_let2.ty.kind is Bool) + let ty2 = ctx.lookup_type("y") + assert_true(ty2 is Some(t) && t.kind is Bool && t.mutable == false) +} + +///| +test "Let Stmt Type Check" { + let code = + #|let x : Double = 42.0; + #|let (x, y) = (1, 2); + #|let z = []; + #|let mat = [[1, 2], z]; + #|let _ = z.push(true); + let ctx = Context::new() + // Parse + let tokens = @parser.tokenize(code) + // Type check for `let x : Int = 42;` + let (let_stmt1, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:]) + let checked_let_stmt1 = ctx.check_let_stmt(let_stmt1) + assert_true(checked_let_stmt1.ty is Double) + // Type check for `let (x, y) = (1, 2);` + let (let_stmt2, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let checked_let_stmt2 = ctx.check_let_stmt(let_stmt2) + assert_true(checked_let_stmt2.ty is Tuple([Int, Int])) + assert_true(ctx.lookup_type("x") is Some({ kind: Int, .. })) + assert_true(ctx.lookup_type("y") is Some({ kind: Int, .. })) + // Type check for `let z = [];` + let (let_stmt3, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let checked_let_stmt3 = ctx.check_let_stmt(let_stmt3) + assert_true(checked_let_stmt3.ty is Array(TypeVar(_))) + // Type check for `let mat = [[1, 2], z];` + let (let_stmt4, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let checked_let_stmt4 = ctx.check_let_stmt(let_stmt4) + assert_true(checked_let_stmt4.ty is Array(Array(Int))) + assert_true(ctx.lookup_type("mat") is Some({ kind: Array(Array(Int)), .. })) + let zty = ctx.lookup_type("z") + assert_true( + ctx.lookup_type("z") is Some({ kind: Array(Int), .. 
}), + msg="Type of z is \{zty}", + ) + // Type check for `let _ = z.push(true);` + // Should fail + let (let_stmt5, _, _) = @parser.parse_stmt_or_expr_end(tok_view) + let check_result = try? ctx.check_let_stmt(let_stmt5) + assert_true(check_result is Err(_)) +} + +///| +test "Assign Stmt Type Check" { + let code = + #|let mut x = 10; + #|x = 5; + #|let arr = [1, 2, 3]; + #|arr[0] = 10; + #|x = x + true; // Should fail + #|let mut a = xxx; // xxx's type is typevar + #|a = a + 33; // `a` and `xxx` should be inferred as Int + let ctx = Context::new() + // Parse + let tokens = @parser.tokenize(code) + // Type of let mut x = 10; x = 5; + let (s, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:]) + let _ = ctx.check_let_mut_stmt(s) + let (a, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let a = ctx.check_assign_stmt(a) + assert_true(a.left_value.ty.kind is Int) + assert_true(a.expr.ty is Int) + // Type of `let arr = [1, 2, 3]; arr[0] = 10;` + let (s, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let _ = ctx.check_let_stmt(s) + let (a, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let a = ctx.check_assign_stmt(a) + assert_true(a.left_value.ty.kind is Int) + assert_true(a.left_value.ty.mutable is true) + assert_true(a.expr.ty is Int) + // Type of `x += true;` should fail + let (a, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let t = try? ctx.check_assign_stmt(a) + assert_true(t is Err(_)) + // Type of `let mut a = xxx; a += 33;` + ctx.type_vars.set(0, TypeVar(0)) + ctx.type_env.set("xxx", { kind: TypeVar(0), mutable: false }) + let (s, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let _ = ctx.check_let_mut_stmt(s) + let (a, _, _) = @parser.parse_stmt_or_expr_end(tok_view) + let _ = ctx.check_assign_stmt(a) + assert_true(ctx.lookup_type("a") is Some({ kind: Int, .. })) + assert_true(ctx.lookup_type("xxx") is Some({ kind: Int, .. 
})) +} + +///| +test "Block Expr TypeCheck Test" { + let code = + #|let arr = []; + #| + #|{ + #| let (x, y) = (1, 2); + #| let mut z = x + y; + #| z = z + 10; + #| arr.push(z); + #|} + #| + #|{ + #| let (x, y) = (1.0, 2.0); + #| let z = x * y; + #| let arr = [1.0, z, 5.0]; + #| arr[2] = arr[1]; + #| arr[2] + #|} + #| + #|{ arr[0] = 1.0; } // should cause type error + let ctx = Context::new() + // Parse + let tokens = @parser.tokenize(code) + let (let_stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:]) + let _ = ctx.check_let_stmt(let_stmt) + // TypeCheck first block expr + let (block_expr1, tok_view) = @parser.parse_block_expr(tok_view) + let checked_block1 = ctx.check_block_expr(block_expr1) + assert_true(checked_block1.stmts.length() is 5) // Lilunar adds an implicit Expr(Unit) at the end + assert_true(checked_block1.ty is Unit) + // TypeCheck second block expr + let (block_expr2, tok_view) = @parser.parse_block_expr(tok_view) + let checked_block2 = ctx.check_block_expr(block_expr2) + assert_true(checked_block2.stmts.length() is 5) + assert_true(checked_block2.ty is Double) + // TypeCheck third block expr (should raise type error) + assert_true(ctx.lookup_type("arr") is Some({ kind: Array(Int), .. })) + let (block_expr3, _) = @parser.parse_block_expr(tok_view) + let checked_block3 = try? 
ctx.check_block_expr(block_expr3) + assert_true(checked_block3 is Err(_)) +} + +///| +test "Expr TypeCheck Test" { + let code = + #|arr[3] + arr[5] ; fact(5) + fib(10) ; + #|mat.data[3][4] ; sum(arr); 3.0 > max(x, y); + let ctx = Context::new() + ctx.struct_defs.set("Matrix", { + name: "Matrix", + fields: [{ name: "data", ty: { kind: Array(Array(Int)), mutable: false } }], + }) + ctx.type_env.set("arr", { kind: Array(Int), mutable: false }) + ctx.type_env.set("mat", { kind: Struct("Matrix"), mutable: false }) + // Note: Matrix struct has a field `data` of type Array(Array(Int)) + ctx.type_env.set("fact", { kind: Function([TypeVar(0)], Int), mutable: false }) + ctx.type_env.set("fib", { kind: Function([Int], Int), mutable: false }) + ctx.type_env.set("sum", { kind: Function([Array(Int)], Int), mutable: false }) + // Note: max is local function without type annotations + ctx.type_env.set("max", { + kind: Function([TypeVar(1), TypeVar(2)], TypeVar(3)), + mutable: false, + }) + ctx.type_env.set("x", { kind: Double, mutable: true }) + ctx.type_env.set("y", { kind: Double, mutable: true }) + ctx.type_vars.set(0, TypeVar(0)) + ctx.type_vars.set(1, TypeVar(1)) + ctx.type_vars.set(2, TypeVar(1)) + ctx.type_vars.set(3, TypeVar(1)) + + // Parse + let tokens = @parser.tokenize(code) + // Type of arr[3] + arr[5] + let (a, tok_view) = @parser.parse_expr(tokens[:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Int) + // Type of fact(5) + fib(10) + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Int) + // Type of mat.data[3][4] + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Int) + // Type of sum(arr) + let (a, tok_view) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Int) + // Type of 3.0 > max(x, y) + let (a, _) = @parser.parse_expr(tok_view[1:]) + let a = ctx.check_expr(a) + assert_true(a.ty is Bool) + 
assert_true(ctx.lookup_type("x") is Some({ kind: Double, .. })) + assert_true(ctx.lookup_type("y") is Some({ kind: Double, .. })) +} + +///| +test "If Expr TypeCheck Test" { + let code = + #|let arr = []; + #| + #|if a > b { + #| arr.push(1); + #|} else { + #| arr.push(2); + #|} + #| + #|if a < b { + #| arr.push(a); + #| a + #|}else if a == b { + #| arr.push(1); + #| a - b + #|} else { + #| arr.push(b); + #| b + #|} + #| + let ctx = Context::new() + ctx.type_env.set("a", { kind: Int, mutable: false }) + ctx.type_env.set("b", { kind: Int, mutable: false }) + // Parse + let tokens = @parser.tokenize(code) + let (let_stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:]) + let _ = ctx.check_let_stmt(let_stmt) + // TypeCheck first if expr + let (if_expr1, tok_view) = @parser.parse_if_expr(tok_view) + let checked_if1 = ctx.check_if_expr(if_expr1) + assert_true(checked_if1.ty is Unit) + assert_true(ctx.lookup_type("arr") is Some({ kind: Array(Int), .. })) + // TypeCheck second if expr + let (if_expr2, _) = @parser.parse_if_expr(tok_view) + let checked_if2 = ctx.check_if_expr(if_expr2) + assert_true(checked_if2.ty is Int) + assert_true(ctx.lookup_type("a") is Some({ kind: Int, .. })) + assert_true(ctx.lookup_type("b") is Some({ kind: Int, .. 
})) +} + +///| +test "While Stmt TypeCheck Test" { + let code = + #|let mut i = 0; + #|while i < 10 { + #| i = i + 1; + #|} + let ctx = Context::new() + // Parse + let tokens = @parser.tokenize(code) + let (let_stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens[:]) + let _ = ctx.check_let_mut_stmt(let_stmt) + let (while_stmt, _, _) = @parser.parse_stmt_or_expr_end(tok_view) + let _ = ctx.check_while_stmt(while_stmt) + +} + +///| +test "Local Function Type Check Test" { + let code = + #|fn max(a, b) { + #| if a > b { a } else { b } + #|} + #| + #|let _ = max(10, 20); + #|let _ = max(1.0, 2.0); // This should raise a type error + let ctx = Context::new() + // Parse + let tokens = @parser.tokenize(code) + guard @parser.parse_stmt_or_expr_end(tokens[:]) + is (LocalFunction(local_func), false, tok_view) + // TypeCheck + let _ = ctx.check_local_function(local_func) + // parse and check let statement + let (let_stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let let_stmt = ctx.check_let_stmt(let_stmt) + assert_true(let_stmt.ty is Int) + let (let_stmt2, _, _) = @parser.parse_stmt_or_expr_end(tok_view) + let let_stmt2 = try? ctx.check_let_stmt(let_stmt2) + assert_true(let_stmt2 is Err(_)) +} + +///| +test "Struct Construct TypeCheck Test" { + let code = + #|struct Point { x: Int; y: Int; } + #|fn main { + #| let p1 = Point::{ x: 1, y: 2 }; + #| let p2 = Point::{ x: 1, y: true }; + #|} + let tokens = @parser.tokenize(code) + let ctx = Context::new() + let program = @parser.parse_program(tokens) + let struct_def = program.struct_defs["Point"] + guard program.top_functions["main"].body is [let_stmt1, let_stmt2, ..] + + // 1. Parse and check the struct definition to populate the context + ctx.struct_defs.set("Point", ctx.check_struct_def(struct_def)) // Assumes check_struct_def works + + // 2. 
Parse and check the valid `let p1 = ...` statement + let _ = ctx.check_let_stmt(let_stmt1) + let p1_type = ctx.lookup_type("p1") + assert_true(p1_type is Some({ kind: Struct("Point"), .. })) + + // 3. Parse and check the invalid `let p2 = ...` statement + let result = try? ctx.check_let_stmt(let_stmt2) + assert_true(result is Err(_)) +} + +///| +test "Top Function TypeCheck Test" { + let code = + #|fn fib(n : Int) -> Int { + #| if n <= 1 { + #| return n; + #| } else { + #| return fib(n - 1) + fib(n - 2); + #| }; + #| 0 + #|} + let ctx = Context::new() + ctx.func_types.set("fib", Function([Int], Int)) + ctx.type_env.set("fib", { kind: Function([Int], Int), mutable: false }) + // parse + let tokens = @parser.tokenize(code) + let program = @parser.parse_program(tokens) + let _ = ctx.check_top_function(program.top_functions["fib"]) + +} + +///| +test "Program TypeCheck Test" { + let code = + #|let a = 3; + #|let b = 4; + #|fn fold(arr: Array[Int], f: (Int, Int) -> Int, init: Int) -> Int { + #| let mut result = init; + #| let mut i = 0; + #| while i < arr.length() { + #| result = f(result, arr[i]); + #| } + #| result + #|} + #| + #|fn main { + #| fn max(a, b) { if a > b { a } else { b } } + #| fn min(a, b) { if a < b { a } else { b } } + #| let numbers = [a, 1, b, 1, 5, 9, 2, 6, 5]; + #| let maximum = fold(numbers, max, -1000); + #| let minimum = fold(numbers, min, 1000); + #| let max_min_diff = maximum - minimum; + #| print_int(max_min_diff); + #|} + let ctx = Context::new() + ctx.type_env.set("print_int", { kind: Function([Int], Unit), mutable: false }) + ctx.func_types.set("print_int", Function([Int], Unit)) + // Type check the program + let tokens = @parser.tokenize(code) + let program = @parser.parse_program(tokens) + let _ = ctx.check_program(program) + +} + +///| +test "TypeCheck Test" { + let code = + #|let a = 3; + #|let b = 4; + #|fn fold(arr: Array[Int], f: (Int, Int) -> Int, init: Int) -> Int { + #| let mut result = init; + #| let mut i = 0; + #| while i < 
arr.length() { + #| result = f(result, arr[i]); + #| } + #| result + #|} + #| + #|fn main { + #| fn max(a, b) { if a > b { a } else { b } } + #| fn min(a, b) { if a < b { a } else { b } } + #| let numbers = [a, 1, b, 1, 5, 9, 2, 6, 5]; + #| let maximum = fold(numbers, max, -1000); + #| let minimum = fold(numbers, min, 1000); + #| let max_min_diff = maximum - minimum; + #|} + let tokens = @parser.tokenize(code) + let program = @parser.parse_program(tokens) + let program = typecheck(program) + let program_str = program.to_string() + assert_false(program_str.contains("TypeVar")) +} diff --git a/src/typecheck/typechecker.mbt b/src/typecheck/typechecker.mbt new file mode 100644 index 0000000..3529b97 --- /dev/null +++ b/src/typecheck/typechecker.mbt @@ -0,0 +1,241 @@ +///| +pub fn typecheck(program : @parser.Program) -> Program raise TypeCheckError { + let ctx = Context::new() + let checked_program = ctx.check_program(program) + ctx.substitute_type_var(checked_program) +} + +///| +pub fn Context::substitute_type_var( + self : Context, + program : Program, +) -> Program { + for name, top_let in program.top_lets { + program.top_lets.set(name, { + name, + expr: self.substitute_type_var_for_expr(top_let.expr), + ty: { + kind: self.deref_type_var(top_let.ty.kind), + mutable: top_let.ty.mutable, + }, + }) + } + for name, top_func in program.top_functions { + program.top_functions.set(name, { + fname: name, + param_list: top_func.param_list.map(param => Param::{ + name: param.name, + ty: self.deref_type_var(param.ty), + }), + ret_ty: self.deref_type_var(top_func.ret_ty), + body: self.substitute_type_var_for_block_expr(top_func.body), + }) + } + program +} + +///| +pub fn Context::substitute_type_var_for_expr( + self : Context, + expr : Expr, +) -> Expr { + { + kind: match expr.kind { + Or(left, right) => + Or( + self.substitute_type_var_for_expr(left), + self.substitute_type_var_for_expr(right), + ) + And(left, right) => + And( + self.substitute_type_var_for_expr(left), + 
self.substitute_type_var_for_expr(right), + ) + MulDivRem(op, left, right) => + MulDivRem( + op, + self.substitute_type_var_for_expr(left), + self.substitute_type_var_for_expr(right), + ) + AddSub(op, left, right) => + AddSub( + op, + self.substitute_type_var_for_expr(left), + self.substitute_type_var_for_expr(right), + ) + Compare(op, left, right) => + Compare( + op, + self.substitute_type_var_for_expr(left), + self.substitute_type_var_for_expr(right), + ) + NegExpr(inner) => NegExpr(self.substitute_type_var_for_expr(inner)) + NotExpr(inner) => NotExpr(self.substitute_type_var_for_expr(inner)) + BlockExpr({ stmts, ty }) => + BlockExpr({ + stmts: stmts.map(stmt => { + kind: match stmt.kind { + ReturnStmt(expr) => + ReturnStmt(self.substitute_type_var_for_expr(expr)) + ExprStmt(expr) => + ExprStmt(self.substitute_type_var_for_expr(expr)) + WhileStmt({ cond, body }) => + WhileStmt({ + cond: self.substitute_type_var_for_expr(cond), + body: self.substitute_type_var_for_block_expr(body), + }) + AssignStmt({ left_value, expr }) => + AssignStmt({ + left_value: self.substitute_type_var_for_left_value( + left_value, + ), + expr: self.substitute_type_var_for_expr(expr), + }) + LetStmt({ pattern, ty, expr }) => + LetStmt({ + pattern: self.substitute_type_var_for_let_pattern(pattern), + ty: self.deref_type_var(ty), + expr: self.substitute_type_var_for_expr(expr), + }) + LetMutStmt({ name, ty, expr }) => + LetMutStmt({ + name, + ty: { + kind: self.deref_type_var(ty.kind), + mutable: ty.mutable, + }, + expr: self.substitute_type_var_for_expr(expr), + }) + LocalFunction({ fname, param_list, ret_ty, body }) => + LocalFunction({ + fname, + param_list: param_list.map(param => ( + param.0, + { + kind: self.deref_type_var(param.1.kind), + mutable: param.1.mutable, + }, + )), + ret_ty: { + kind: self.deref_type_var(ret_ty.kind), + mutable: ret_ty.mutable, + }, + body: self.substitute_type_var_for_block_expr(body), + }) + }, + }), + ty: self.deref_type_var(ty), + }) + ApplyExpr(apply) 
=> + ApplyExpr(self.substitute_type_var_for_apply_expr(apply)) + IfExpr({ cond, then_block, else_block, ty }) => + IfExpr({ + cond: self.substitute_type_var_for_expr(cond), + then_block: self.substitute_type_var_for_block_expr(then_block), + else_block: else_block.map(eb => self.substitute_type_var_for_expr(eb)), + ty: self.deref_type_var(ty), + }) + }, + ty: self.deref_type_var(expr.ty), + } +} + +///| +fn Context::substitute_type_var_for_block_expr( + self : Context, + block : BlockExpr, +) -> BlockExpr { + guard self.substitute_type_var_for_expr({ kind: BlockExpr(block), ty: Unit }).kind // only use kind + is BlockExpr(block) + block +} + +///| +fn Context::substitute_type_var_for_left_value( + self : Context, + left_value : LeftValue, +) -> LeftValue { + { + kind: match left_value.kind { + Ident(name) => Ident(name) + ArrayAccess(base, index_expr) => + ArrayAccess( + self.substitute_type_var_for_left_value(base), + self.substitute_type_var_for_expr(index_expr), + ) + FieldAccess(base, field_name) => + FieldAccess(self.substitute_type_var_for_left_value(base), field_name) + }, + ty: { + kind: self.deref_type_var(left_value.ty.kind), + mutable: left_value.ty.mutable, + }, + } +} + +///| +fn Context::substitute_type_var_for_let_pattern( + self : Context, + pattern : Pattern, +) -> Pattern { + { + kind: match pattern.kind { + Ident(name) => Ident(name) + Wildcard => Wildcard + Tuple(patterns) => + Tuple(patterns.map(p => self.substitute_type_var_for_let_pattern(p))) + }, + } +} + +///| +fn Context::substitute_type_var_for_apply_expr( + self : Context, + apply_expr : ApplyExpr, +) -> ApplyExpr { + { + kind: match apply_expr.kind { + AtomExpr({ kind, ty }) => + AtomExpr({ + kind: match kind { + Int(v) => Int(v) + Double(v) => Double(v) + Bool(v) => Bool(v) + Unit => Unit + Ident(name) => Ident(name) + Tuple(elems) => + Tuple(elems.map(elem => self.substitute_type_var_for_expr(elem))) + Array(elems) => + Array(elems.map(elem => 
self.substitute_type_var_for_expr(elem))) + ArrayMake(size_expr, init_expr) => + ArrayMake( + self.substitute_type_var_for_expr(size_expr), + self.substitute_type_var_for_expr(init_expr), + ) + StructConstruct({ name, fields }) => + StructConstruct({ + name, + fields: fields.map(field => { + let (field_name, field_expr) = field + (field_name, self.substitute_type_var_for_expr(field_expr)) + }), + }) + }, + ty: self.deref_type_var(ty), + }) + ArrayAccess(base, index_expr) => + ArrayAccess( + self.substitute_type_var_for_apply_expr(base), + self.substitute_type_var_for_expr(index_expr), + ) + FieldAccess(base, field_name) => + FieldAccess(self.substitute_type_var_for_apply_expr(base), field_name) + Call(callee, args) => + Call( + self.substitute_type_var_for_apply_expr(callee), + args.map(arg => self.substitute_type_var_for_expr(arg)), + ) + }, + ty: self.deref_type_var(apply_expr.ty), + } +} diff --git a/src/typecheck/typedef.mbt b/src/typecheck/typedef.mbt new file mode 100644 index 0000000..571f547 --- /dev/null +++ b/src/typecheck/typedef.mbt @@ -0,0 +1,163 @@ +///| +pub(all) struct Type { + kind : TypeKind + mutable : Bool +} derive(Show) + +///| +pub(all) enum TypeKind { + Unit + Bool + Int + Double + Tuple(Array[TypeKind]) + Array(TypeKind) + Function(Array[TypeKind], TypeKind) + Struct(String) + Any + TypeVar(Int) +} derive(Eq, Hash) + +///| +pub fn Context::check_parser_type( + self : Context, + ty : @parser.Type, + mutable? 
: Bool = false, +) -> Type raise TypeCheckError { + match ty { + UserDefined(name) => { + if self.type_env.get("$Generic$\{name}") is Some(t) { + return t + } + if !self.struct_defs.contains(name) { + raise TypeCheckError("Undefined type: \{name}") + } + { kind: Struct(name), mutable } + } + Function(param_types, return_type) => { + let param_types = param_types.map(param_type => self.check_parser_type( + param_type, + ).kind) + let return_type = self.check_parser_type(return_type).kind + { kind: Function(param_types, return_type), mutable } + } + Tuple(types) => + { + kind: Tuple(types.map(type_ => self.check_parser_type(type_).kind)), + mutable, + } + Array(elem_type) => + { kind: Array(self.check_parser_type(elem_type).kind), mutable } + Double => { kind: Double, mutable } + Int => { kind: Int, mutable } + Bool => { kind: Bool, mutable } + Unit => { kind: Unit, mutable } + Generic(_) => ... + } +} + +///| +pub fn Context::deref_type_var(self : Context, ty : TypeKind) -> TypeKind { + match ty { + TypeVar(r) if self.type_vars[r] == TypeVar(r) => ty + TypeVar(r) => self.deref_type_var(self.type_vars[r]) // 要求不能有环 + Tuple(elem_types) => + Tuple(elem_types.map(typekind => self.deref_type_var(typekind))) + Array(elem_type) => Array(self.deref_type_var(elem_type)) + Function(param_types, return_type) => + Function( + param_types.map(typekind => self.deref_type_var(typekind)), + self.deref_type_var(return_type), + ) + _ => ty + } +} + +///| +pub fn Context::is_type_compatible( + self : Context, + a : TypeKind, + b : TypeKind, +) -> Bool { + fn set_type_vars(self : Context, id : Int, target : TypeKind) -> Unit { + if target == TypeVar(id) { + return + } + for var_id, var_type in self.type_vars { + if var_type == TypeVar(id) { + self.type_vars[var_id] = target + set_type_vars(self, var_id, target) + } + } + } + + let deref_a = self.deref_type_var(a) + let deref_b = self.deref_type_var(b) + match (deref_a, deref_b) { + (TypeVar(id_a), TypeVar(id_b)) => { + if id_a != id_b 
{ + set_type_vars(self, id_a, TypeVar(id_b)) + } + true + } + (TypeVar(id), other) | (other, TypeVar(id)) => { + set_type_vars(self, id, other) + true + } + (Unit, Unit) => true + (Bool, Bool) => true + (Int, Int) => true + (Double, Double) => true + (Tuple(types_a), Tuple(types_b)) => { + if types_a.length() != types_b.length() { + return false + } + let mut result = true + for i in 0.. self.is_type_compatible(elem_a, elem_b) + (Function(params_a, ret_a), Function(params_b, ret_b)) => { + if params_a.length() != params_b.length() { + return false + } + let mut result = true + for i in 0.. name_a == name_b + (Any, _) | (_, Any) => true + _ => false + } +} + +///| +pub impl Show for TypeKind with output(self, logger) { + let s = match self { + Unit => "Unit" + Bool => "Bool" + Int => "Int" + Double => "Double" + Tuple(types) => { + let inner = types.map(typekind => typekind.to_string()).join(", ") + "(\{inner})" + } + Array(elem_type) => "Array[\{elem_type.to_string()}]" + Function(param_types, return_type) => { + let params = param_types.map(typekind => typekind.to_string()).join(", ") + "(\{params}) -> \{return_type.to_string()}" + } + Struct(name) => "\{name}" + Any => "Any" + TypeVar(name) => "TypeVar(\{name})" + } + logger.write_string(s) +}