From b7cbbdfbba081617fb93f76153e635118df46503 Mon Sep 17 00:00:00 2001 From: Lil-Ran Date: Tue, 11 Nov 2025 05:28:39 +0800 Subject: [PATCH] feat: knf --- README.mbt.md | 4 +- batch_run.py | 2 +- minimoonbit.json | 2 +- moon.mod.json | 5 +- src/bin/main.mbt | 32 +- src/bin/moon.pkg.json | 3 +- src/knf/apply_expr.mbt | 150 ++++ src/knf/assign_stmt.mbt | 67 ++ src/knf/block.mbt | 51 ++ src/knf/closure.mbt | 83 ++ src/knf/context.mbt | 134 +++ src/knf/expr.mbt | 228 +++++ src/knf/function.mbt | 62 ++ src/knf/knf.mbt | 85 ++ src/knf/knf_test.mbt | 1417 ++++++++++++++++++++++++++++++++ src/knf/let_stmt.mbt | 75 ++ src/knf/moon.pkg.json | 8 + src/knf/stmt.mbt | 99 +++ src/knf/struct_def.mbt | 47 ++ src/knf/top_let.mbt | 29 + src/knf/type.mbt | 64 ++ src/typecheck/expr_atom.mbt | 8 + src/typecheck/expr_block.mbt | 13 +- src/typecheck/top_function.mbt | 7 +- src/typecheck/typechecker.mbt | 2 +- 25 files changed, 2654 insertions(+), 23 deletions(-) create mode 100644 src/knf/apply_expr.mbt create mode 100644 src/knf/assign_stmt.mbt create mode 100644 src/knf/block.mbt create mode 100644 src/knf/closure.mbt create mode 100644 src/knf/context.mbt create mode 100644 src/knf/expr.mbt create mode 100644 src/knf/function.mbt create mode 100644 src/knf/knf.mbt create mode 100644 src/knf/knf_test.mbt create mode 100644 src/knf/let_stmt.mbt create mode 100644 src/knf/moon.pkg.json create mode 100644 src/knf/stmt.mbt create mode 100644 src/knf/struct_def.mbt create mode 100644 src/knf/top_let.mbt create mode 100644 src/knf/type.mbt diff --git a/README.mbt.md b/README.mbt.md index 97154f4..dadb98e 100644 --- a/README.mbt.md +++ b/README.mbt.md @@ -1,6 +1,8 @@ # Lilunar -Lil-Ran's MiniMoonBit to RISC-V compiler for [MGPIC-2025](https://www.moonbitlang.cn/2025-mgpic-compiler). +🚧 **To be continued...** + +Lil-Ran's MiniMoonBit to LLVM IR compiler for [MGPIC-2025](https://www.moonbitlang.cn/2025-mgpic-compiler). It is not optimized yet. diff --git a/batch_run.py b/batch_run.py index 9f32079..5442e4b 100644 --- a/batch_run.py +++ b/batch_run.py @@ -32,7 +32,7 @@ async def main(): for file in os.listdir("contest-2025-data/test_cases/mbt"): if file.endswith(".mbt"): in_path = os.path.join("contest-2025-data/test_cases/mbt", file) - out_path = os.path.join("output/repo", file.replace(".mbt", ".s")) + out_path = os.path.join("output/repo", file.replace(".mbt", ".ll")) tasks.append(check_file(in_path, out_path)) await asyncio.gather(*tasks) diff --git a/minimoonbit.json b/minimoonbit.json index e0b25bf..1226118 100644 --- a/minimoonbit.json +++ b/minimoonbit.json @@ -1,3 +1,3 @@ { - "emit": "asm" + "emit": "llvm" } \ No newline at end of file diff --git a/moon.mod.json b/moon.mod.json index a195e84..ba04ab0 100644 --- a/moon.mod.json +++ b/moon.mod.json @@ -11,10 +11,9 @@ "keywords": [ "MiniMoonBit", "compiler", - "RISC-V", - "assembly" + "MGPIC-2025" ], - "description": "My MiniMoonBit to RISC-V compiler for MGPIC-2025.", + "description": "Lil-Ran's MiniMoonBit to LLVM IR compiler for MGPIC-2025.", "source": "src", "preferred-target": "wasm-gc" } \ No newline at end of file diff --git a/src/bin/main.mbt b/src/bin/main.mbt index 83bd70a..06ed2a6 100644 --- a/src/bin/main.mbt +++ b/src/bin/main.mbt @@ -42,7 +42,7 @@ fn main { in_file = Some(s) }, ( - #|Lilunar: Lil-Ran's experimental MiniMoonBit to RISC-V compiler for MGPIC-2025. + #|Lilunar: Lil-Ran's experimental MiniMoonBit to LLVM IR compiler for MGPIC-2025. #| #| usage: lilunar [options] #| @@ -61,7 +61,7 @@ fn main { } println_debug( ( - #| + #|Source: #|================================ $|\{contents} #|================================ @@ -82,11 +82,25 @@ fn main { println_debug("Type checking passed.") return } - - // let asm_string = ... - // if out_file.val == "-" { - // println(asm_string) - // } else { - // @fs.write_string_to_file?(out_file.val, asm_string).unwrap() - // } + let knf = @knf.knf_transform(program) catch { + @knf.KnfTransformError(msg) => + println_panic("KNF transformation error: \{msg}") + } + println_debug( + ( + #|KNF: + #|================================ + $|\{knf} + #|================================ + ), + ) + let output_string = knf.to_string() + if out_file.val == "-" { + println(output_string) + } else { + @fs.write_string_to_file(out_file.val, output_string) catch { + @fs.IOError(msg) => + println_panic("Failed to write to output file \{out_file.val}: \{msg}") + } + } } diff --git a/src/bin/moon.pkg.json b/src/bin/moon.pkg.json index 867b4cf..1aeac98 100644 --- a/src/bin/moon.pkg.json +++ b/src/bin/moon.pkg.json @@ -5,6 +5,7 @@ "Yoorkin/ArgParser", "moonbitlang/x/fs", "Lil-Ran/lilunar/parser", - "Lil-Ran/lilunar/typecheck" + "Lil-Ran/lilunar/typecheck", + "Lil-Ran/lilunar/knf" ] } \ No newline at end of file diff --git a/src/knf/apply_expr.mbt b/src/knf/apply_expr.mbt new file mode 100644 index 0000000..a6ff4e4 --- /dev/null +++ b/src/knf/apply_expr.mbt @@ -0,0 +1,150 @@ +///| +pub fn Context::apply_expr_to_knf( + self : Context, + apply_expr : @typecheck.ApplyExpr, +) -> (Array[KnfStmt], KnfExpr) raise KnfTransformError { + match apply_expr.kind { + AtomExpr(atom_expr) => self.atom_expr_to_knf(atom_expr) + ArrayAccess(array_expr, index_expr) => { + let stmts = [] + let (array_stmts, array_knf_expr) = self.apply_expr_to_knf(array_expr) + stmts.append(array_stmts) + let array_name = self.expr_to_knf_name( + array_knf_expr, + self.typekind_to_knf(array_expr.ty), + stmts, + ) + let (index_stmts, index_knf_expr) = self.expr_to_knf(index_expr) + stmts.append(index_stmts) + let index_name = self.expr_to_knf_name( + index_knf_expr, + self.typekind_to_knf(index_expr.ty), + stmts, + ) + let knf_expr = ArrayAccess(array_name, index_name) + (stmts, knf_expr) + } + FieldAccess(struct_expr, field_name) => { + let stmts = [] + let (struct_stmts, struct_knf_expr) = self.apply_expr_to_knf(struct_expr) + stmts.append(struct_stmts) + let struct_name = self.expr_to_knf_name( + struct_knf_expr, + self.typekind_to_knf(struct_expr.ty), + stmts, + ) + let knf_expr = FieldAccess(struct_name, field_name) + (stmts, knf_expr) + } + Call(callee_expr, arg_exprs) => { + let stmts = [] + let (callee_stmts, callee_knf_expr) = self.apply_expr_to_knf(callee_expr) + stmts.append(callee_stmts) + let callee_name = self.expr_to_knf_name( + callee_knf_expr, + self.typekind_to_knf(callee_expr.ty), + stmts, + ) + let arg_names = [] + for arg_expr in arg_exprs { + let (arg_stmts, arg_knf_expr) = self.expr_to_knf(arg_expr) + stmts.append(arg_stmts) + let arg_name = self.expr_to_knf_name( + arg_knf_expr, + self.typekind_to_knf(arg_expr.ty), + stmts, + ) + arg_names.push(arg_name) + } + let knf_expr = Call(callee_name, arg_names) + (stmts, knf_expr) + } + } +} + +///| +pub fn Context::atom_expr_to_knf( + self : Context, + atom_expr : @typecheck.AtomExpr, +) -> (Array[KnfStmt], KnfExpr) raise KnfTransformError { + let { kind, ty } = atom_expr + match kind { + Int(v) => ([], Int(v)) + Double(v) => ([], Double(v)) + Bool(v) => ([], Bool(v)) + Unit => ([], Unit) + Ident(s) => { + guard self.lookup_name(s) is Some((name, _)) else { + raise KnfTransformError("undefined identifier in atom expression: \{s}") + } + ([], Ident(name)) + } + Array(elems) => { + guard ty is Array(elem_ty) else { + raise KnfTransformError("expected array type in array literal") + } + let stmts = [] + let elem_names = [] + for elem in elems { + let (elem_stmts, elem_expr) = self.expr_to_knf(elem) + stmts.append(elem_stmts) + let elem_name = self.expr_to_knf_name( + elem_expr, + self.typekind_to_knf(elem.ty), + stmts, + ) + elem_names.push(elem_name) + } + (stmts, ArrayLiteral(self.typekind_to_knf(elem_ty), elem_names)) + } + Tuple(elems) => { + let stmts = [] + let elem_names = [] + for elem in elems { + let (elem_stmts, elem_expr) = self.expr_to_knf(elem) + stmts.append(elem_stmts) + let elem_name = self.expr_to_knf_name( + elem_expr, + self.typekind_to_knf(elem.ty), + stmts, + ) + elem_names.push(elem_name) + } + (stmts, TupleLiteral(elem_names)) + } + ArrayMake(size_expr, init_expr) => { + let stmts = [] + let (size_stmts, size_knf_expr) = self.expr_to_knf(size_expr) + stmts.append(size_stmts) + let size_name = self.expr_to_knf_name( + size_knf_expr, + self.typekind_to_knf(size_expr.ty), + stmts, + ) + let (init_stmts, init_knf_expr) = self.expr_to_knf(init_expr) + stmts.append(init_stmts) + let init_name = self.expr_to_knf_name( + init_knf_expr, + self.typekind_to_knf(init_expr.ty), + stmts, + ) + (stmts, ArrayMake(size_name, init_name)) + } + StructConstruct({ name, fields }) => { + let stmts = [] + let init_names = [] + for field in fields { + let (field_name, field_expr) = field + let (field_stmts, field_knf_expr) = self.expr_to_knf(field_expr) + stmts.append(field_stmts) + let field_name_knf = self.expr_to_knf_name( + field_knf_expr, + self.typekind_to_knf(field_expr.ty), + stmts, + ) + init_names.push((field_name, field_name_knf)) + } + (stmts, CreateStruct(name, init_names)) + } + } +} diff --git a/src/knf/assign_stmt.mbt b/src/knf/assign_stmt.mbt new file mode 100644 index 0000000..9a45817 --- /dev/null +++ b/src/knf/assign_stmt.mbt @@ -0,0 +1,67 @@ +///| +pub fn Context::assign_stmt_to_knf( + self : Context, + assign_stmt : @typecheck.AssignStmt, +) -> Array[KnfStmt] raise KnfTransformError { + let { left_value, expr } = assign_stmt + let stmts = [] + let (expr_stmts, expr_knf_expr) = self.expr_to_knf(expr) + stmts.append(expr_stmts) + let expr_ty = self.typekind_to_knf(expr.ty) + let (left_value_stmts, left_value_knf_expr) = self.left_value_to_knf( + left_value, + ) + stmts.append(left_value_stmts) + match left_value_knf_expr { + Ident(name) => stmts.push(Assign(name, expr_knf_expr)) + ArrayAccess(array_name, index_name) => + stmts.push(ArrayPut(array_name, index_name, expr_knf_expr)) + FieldAccess(struct_name, field_name) => + stmts.push( + StructFieldSet( + struct_name, + field_name, + self.expr_to_knf_name(expr_knf_expr, expr_ty, stmts), + ), + ) + _ => panic() // left_value_to_knf should not produce other kinds + } + stmts +} + +///| +pub fn Context::left_value_to_knf( + self : Context, + left_value : @typecheck.LeftValue, +) -> (Array[KnfStmt], KnfExpr) raise KnfTransformError { + match left_value.kind { + Ident(name_str) => { + let (name, _) = self + .lookup_name(name_str) + .or_error( + KnfTransformError("undefined identifier in left value: \{name_str}"), + ) + ([], Ident(name)) + } + ArrayAccess(array_expr, index_expr) => { + let stmts = [] + let (array_stmts, array_knf_expr) = self.left_value_to_knf(array_expr) + stmts.append(array_stmts) + let array_ty = self.type_to_knf(array_expr.ty) + let array_name = self.expr_to_knf_name(array_knf_expr, array_ty, stmts) + let (index_stmts, index_knf_expr) = self.expr_to_knf(index_expr) + stmts.append(index_stmts) + let index_ty = self.typekind_to_knf(index_expr.ty) + let index_name = self.expr_to_knf_name(index_knf_expr, index_ty, stmts) + (stmts, ArrayAccess(array_name, index_name)) + } + FieldAccess(struct_expr, field_name) => { + let stmts = [] + let (struct_stmts, struct_knf_expr) = self.left_value_to_knf(struct_expr) + stmts.append(struct_stmts) + let struct_ty = self.type_to_knf(struct_expr.ty) + let struct_name = self.expr_to_knf_name(struct_knf_expr, struct_ty, stmts) + (stmts, FieldAccess(struct_name, field_name)) + } + } +} diff --git a/src/knf/block.mbt b/src/knf/block.mbt new file mode 100644 index 0000000..2024909 --- /dev/null +++ b/src/knf/block.mbt @@ -0,0 +1,51 @@ +///| +pub(all) struct KnfBlock { + stmts : Array[KnfStmt] + ty : Type +} + +///| +pub fn Context::block_expr_to_knf( + self : Context, + expr : @typecheck.BlockExpr, +) -> KnfBlock raise KnfTransformError { + let stmts = [] + for stmt in expr.stmts { + stmts.append(self.stmt_to_knf(stmt)) + } + if stmts is [.., ExprStmt(Unit)] { + ignore(stmts.pop()) + } + let ty = self.typekind_to_knf(expr.ty) + { stmts, ty } +} + +///| +pub fn KnfBlock::to_string(self : KnfBlock, ident : Int) -> String { + let sb = StringBuilder::new() + sb.write_string("{\n") + for stmt in self.stmts { + sb.write_string(stmt.to_string(ident=ident + 2)) + sb.write_string("\n") + } + sb.write_string(" ".repeat(ident)) + sb.write_string("}") + sb.to_string() +} + +///| +pub fn KnfBlock::nested_to_string(self : KnfBlock) -> String { + let sb = StringBuilder::new() + sb.write_string("{") + for stmt in self.stmts { + sb.write_string(stmt.to_string(ident=0)) + sb.write_string(" ") + } + sb.write_string("}") + sb.to_string() +} + +///| +pub impl Show for KnfBlock with output(self, logger) { + logger.write_string(self.to_string(0)) +} diff --git a/src/knf/closure.mbt b/src/knf/closure.mbt new file mode 100644 index 0000000..ba8aace --- /dev/null +++ b/src/knf/closure.mbt @@ -0,0 +1,83 @@ +///| +pub(all) struct KnfClosure { + name : Name + params : Array[(Name, Type)] + ret_ty : Type + body : KnfBlock + captured_vars : Map[Name, Type] +} + +///| +pub fn Context::local_function_to_knf( + self : Context, + local_function : @typecheck.LocalFunction, +) -> KnfClosure raise KnfTransformError { + let { fname, param_list, ret_ty, body } = local_function + + // 1. 函数类型构造 + let knf_params = [] + let param_types = [] + for param in param_list { + let (param_name, param_ty) = param + let knf_param_ty = self.type_to_knf(param_ty) + knf_params.push((param_name, knf_param_ty)) + param_types.push(knf_param_ty) + } + let knf_ret_ty = self.type_to_knf(ret_ty) + let func_type = Function(param_types, knf_ret_ty) + let knf_func_name = self.add_new_name(fname, func_type) + + // 2. 进入函数作用域 + self.enter_scope() + + // 3. 参数处理和函数体转换 + let knf_param_names = [] + for param in knf_params { + let (param_name, param_ty) = param + let knf_param_name = self.add_new_name(param_name, param_ty) + knf_param_names.push((knf_param_name, param_ty)) + } + let knf_body = self.block_expr_to_knf(body) + + // 捕获的变量 + let captured_vars = Map::new() + for name in self.name_env.capture.keys() { + let ty = self.name_env.capture.get(name).unwrap() + captured_vars.set(name, ty) + } + + // 4. 退出作用域并构造闭包 + self.exit_scope() + { + name: knf_func_name, + params: knf_param_names, + ret_ty: knf_ret_ty, + body: knf_body, + captured_vars, + } +} + +///| +pub fn KnfClosure::to_string(self : KnfClosure, ident? : Int = 0) -> String { + let sb = StringBuilder::new() + let indent_str = " ".repeat(ident) + if !self.captured_vars.is_empty() { + sb.write_string("// Captured variables: \n") + for name, ty in self.captured_vars { + sb.write_string(indent_str) + sb.write_string("// - \{name} : \{ty}\n") + } + sb.write_string(indent_str) + sb.write_string("fn \{self.name}(") + } else { + sb.write_string("fn \{self.name}(") + } + let param_strs = self.params.map(name_ty => { + let (name, ty) = name_ty + "\{name} : \{ty}" + }) + sb.write_string(param_strs.join(", ")) + sb.write_string(") -> \{self.ret_ty} ") + sb.write_string(self.body.to_string(ident)) + sb.to_string() +} diff --git a/src/knf/context.mbt b/src/knf/context.mbt new file mode 100644 index 0000000..adc14d4 --- /dev/null +++ b/src/knf/context.mbt @@ -0,0 +1,134 @@ +///| +pub(all) suberror KnfTransformError String derive(Show) + +///| +pub(all) struct Name { + id : String + slot : Int +} derive(Hash, Eq) + +///| +pub fn Name::wildcard() -> Name { + Name::{ id: "_", slot: 0 } +} + +///| +pub impl Show for Name with output(self, logger) { + logger.write_string(self.id) + if self.slot > 0 { + logger.write_string("$\{self.slot}") + } +} + +///| +pub(all) struct Env { + local_ : Map[String, (Name, Type)] // defined in this scope + capture : Map[Name, Type] // captured from outer scopes + parent : Env? +} + +///| +pub fn Env::new(parent? : Env? = None) -> Env { + Env::{ local_: Map::new(), capture: Map::new(), parent } +} + +///| +pub fn Env::get_name_type(self : Env, name : Name) -> Type? { + let { id, .. } = name + match self.local_.get(id) { + Some((_, t)) => Some(t) + None => + match self.parent { + Some(p) => p.get_name_type(name) + None => None + } + } +} + +///| +pub fn Env::get(self : Env, name : String, no_capture? : Bool = false) -> Name? { + match self.local_.get(name) { + Some((n, _)) => Some(n) + None => + match self.parent { + Some(p) => + match p.get(name, no_capture~) { + Some(n) => { + if !no_capture { + let ty = p.get_name_type(n).unwrap() + self.capture.set(n, ty) + } + Some(n) + } + None => None + } + None => None + } + } +} + +///| +pub fn Env::set(self : Env, s : String, name : Name, ty : Type) -> Unit { + self.local_.set(s, (name, ty)) +} + +///| +pub(all) struct Context { + mut name_env : Env + capture : Array[Name] + globals : Map[String, Type] +} + +///| +pub fn Context::new() -> Context { + Context::{ name_env: Env::new(), capture: Array::new(), globals: Map::new() } +} + +///| +pub fn Context::lookup_name(self : Context, s : String) -> (Name, Type)? { + let local_ = self.name_env.get(s) + if local_ is Some(name) { + return Some((name, self.name_env.get_name_type(name).unwrap())) + } + let global = self.globals.get(s) + if global is Some(ty) { + return Some(({ id: s, slot: 0 }, ty)) + } + None +} + +///| +pub fn Context::enter_scope(self : Context) -> Unit { + let sub_env = Env::new(parent=Some(self.name_env)) + self.name_env = sub_env +} + +///| +pub fn Context::exit_scope(self : Context) -> Unit { + self.name_env = match self.name_env.parent { + Some(p) => p + None => self.name_env + } +} + +///| +pub fn Context::add_new_name(self : Context, s : String, ty : Type) -> Name { + match self.name_env.get(s, no_capture=true) { + Some({ id, slot }) => { + let name = Name::{ id, slot: slot + 1 } + self.name_env.set(s, name, ty) + name + } + None => { + let name = Name::{ id: s, slot: 0 } + self.name_env.set(s, name, ty) + name + } + } +} + +///| +pub fn Context::add_temp(self : Context, ty : Type) -> Name { + let temp_id = "tmp" + self.add_new_name(temp_id, ty) +} diff --git a/src/knf/expr.mbt b/src/knf/expr.mbt new file mode 100644 index 0000000..09e03d4 --- /dev/null +++ b/src/knf/expr.mbt @@ -0,0 +1,228 @@ +///| +pub(all) enum KnfExpr { + Unit + Int(Int) + Bool(Bool) + Double(Double) + Ident(Name) + Not(Name) + Neg(Name) + Binary(BinaryOp, Name, Name) + If(KnfExpr, KnfBlock, KnfBlock) + Block(KnfBlock) + Call(Name, Array[Name]) + ArrayAccess(Name, Name) + FieldAccess(Name, String) + TupleAccess(Name, Int) + CreateStruct(String, Array[(String, Name)]) + ArrayLiteral(Type, Array[Name]) + ArrayMake(Name, Name) + TupleLiteral(Array[Name]) +} + +///| +pub(all) enum BinaryOp { + Add // + + Sub // - + Mul // * + Div // / + Mod // % + Eq // == + NE // != + LT // < + GT // > + LE // <= + GE // >= + And // && + Or // || +} derive(Eq) + +///| +pub fn Context::expr_to_knf( + self : Context, + expr : @typecheck.Expr, +) -> (Array[KnfStmt], KnfExpr) raise KnfTransformError { + match expr.kind { + ApplyExpr(apply_expr) => self.apply_expr_to_knf(apply_expr) + NotExpr(inner) | NegExpr(inner) => { + let stmts = [] + let (inner_stmts, inner_expr) = self.expr_to_knf(inner) + stmts.append(inner_stmts) + let ty = self.typekind_to_knf(inner.ty) + let tmp_name = self.add_temp(ty) + stmts.push(Let(tmp_name, ty, inner_expr)) + let knf_expr = match expr.kind { + NotExpr(_) => Not(tmp_name) + NegExpr(_) => Neg(tmp_name) + _ => panic() + } + (stmts, knf_expr) + } + Compare(op, lhs, rhs) => { + let op = match op { + Equal => Eq + NotEqual => NE + Less => LT + Greater => GT + LessEqual => LE + GreaterEqual => GE + } + self.binary_expr_to_knf(op, lhs, rhs) + } + AddSub(op, lhs, rhs) => { + let op = match op { + Add => Add + Sub => Sub + } + self.binary_expr_to_knf(op, lhs, rhs) + } + MulDivRem(op, lhs, rhs) => { + let op = match op { + Mul => Mul + Div => Div + Rem => Mod + } + self.binary_expr_to_knf(op, lhs, rhs) + } + Or(lhs, rhs) => self.binary_expr_to_knf(Or, lhs, rhs) + And(lhs, rhs) => self.binary_expr_to_knf(And, lhs, rhs) + BlockExpr(block_expr) => { + let knf_block = self.block_expr_to_knf(block_expr) + ([], Block(knf_block)) + } + IfExpr(if_expr) => self.if_expr_to_knf(if_expr) + } +} + +///| +pub fn Context::expr_to_knf_name( + self : Context, + expr : KnfExpr, + ty : Type, + stmts : Array[KnfStmt], +) -> Name { + match expr { + Ident(name) => name + _ => { + let tmp_name = self.add_temp(ty) + stmts.push(Let(tmp_name, ty, expr)) + tmp_name + } + } +} + +///| +pub fn Context::binary_expr_to_knf( + self : Context, + op : BinaryOp, + lhs : @typecheck.Expr, + rhs : @typecheck.Expr, +) -> (Array[KnfStmt], KnfExpr) raise KnfTransformError { + let stmts = [] + let (lhs_stmts, lhs_expr) = self.expr_to_knf(lhs) + let (rhs_stmts, rhs_expr) = self.expr_to_knf(rhs) + stmts.append(lhs_stmts) + stmts.append(rhs_stmts) + let ty = self.typekind_to_knf(lhs.ty) + let lhs_name = self.expr_to_knf_name(lhs_expr, ty, stmts) + let rhs_name = self.expr_to_knf_name(rhs_expr, ty, stmts) + let knf_expr = Binary(op, lhs_name, rhs_name) + (stmts, knf_expr) +} + +///| +pub fn Context::if_expr_to_knf( + self : Context, + if_expr : @typecheck.IfExpr, +) -> (Array[KnfStmt], KnfExpr) raise KnfTransformError { + let stmts = [] + let (cond_stmts, cond_knf_expr) = self.expr_to_knf(if_expr.cond) + stmts.append(cond_stmts) + let then_block = self.block_expr_to_knf(if_expr.then_block) + let else_block = match if_expr.else_block { + None => { stmts: [], ty: Unit } + Some(expr) => { + let (else_stmts, nested_if_knf) = self.expr_to_knf(expr) + stmts.append(else_stmts) + match nested_if_knf { + Block(knf_block) => knf_block + _ => + { + stmts: [ExprStmt(nested_if_knf)], + ty: self.typekind_to_knf(expr.ty), + } + } + } + } + (stmts, If(cond_knf_expr, then_block, else_block)) +} + +///| +pub impl Show for BinaryOp with output(self, logger) { + let s = match self { + Add => "+" + Sub => "-" + Mul => "*" + Div => "/" + Mod => "%" + Eq => "==" + NE => "!=" + LT => "<" + GT => ">" + LE => "<=" + GE => ">=" + And => "&&" + Or => "||" + } + logger.write_string(s) +} + +///| +pub fn KnfExpr::to_string(self : KnfExpr, ident? : Int = 0) -> String { + match self { + Unit => "()" + Int(i) => i.to_string() + Bool(b) => b.to_string() + Double(d) => d.to_string() + Ident(name) => name.to_string() + Not(name) => "!\{name}" + Neg(name) => "-\{name}" + Binary(op, lhs, rhs) => "\{lhs} \{op} \{rhs}" + Call(func_name, args) => { + let args_strs = args.map(arg => arg.to_string()).join(", ") + "\{func_name}(\{args_strs})" + } + ArrayAccess(array_name, index_name) => "\{array_name}[\{index_name}]" + FieldAccess(struct_name, field_name) => "\{struct_name}.\{field_name}" + TupleAccess(tuple_name, index) => "\{tuple_name}.\{index}" + CreateStruct(struct_name, init_arr) => { + let init_strs = init_arr.map(field => "\{field.0}: \{field.1}").join(", ") + "\{struct_name}::{\{init_strs}}" + } + ArrayLiteral(ty, elem_names) => { + let elems_strs = elem_names.map(elem => elem.to_string()).join(", ") + "[\{elems_strs}]::Array[\{ty}]" + } + ArrayMake(size_name, init_name) => "array_make(\{size_name}, \{init_name})" + TupleLiteral(elem_names) => { + let elems_strs = elem_names.map(elem => elem.to_string()).join(", ") + "(\{elems_strs})" + } + Block(block) => block.to_string(ident) + If(cond, then_block, else_block) => { + let cond_str = cond.to_string() + let then_str : String = then_block.to_string(ident) + if else_block.stmts.is_empty() { + "if \{cond_str} \{then_str}" + } else { + let else_str : String = else_block.to_string(ident) + "if \{cond_str} \{then_str} else \{else_str}" + } + } + } +} + +///| +pub impl Show for KnfExpr with output(self, logger) { + logger.write_string(self.to_string(ident=0)) +} diff --git a/src/knf/function.mbt b/src/knf/function.mbt new file mode 100644 index 0000000..d063489 --- /dev/null +++ b/src/knf/function.mbt @@ -0,0 +1,62 @@ +///| +pub(all) struct KnfFunction { + name : String + ret_ty : Type + params : Array[(Name, Type)] + body : KnfBlock +} + +///| +pub fn Context::top_function_to_knf( + self : Context, + top_func : @typecheck.TopFunction, +) -> KnfFunction raise KnfTransformError { + let { fname, param_list, ty, body } = top_func + let func_type = self.typekind_to_knf(ty) + self.globals.set(fname, func_type) + + // 1. 进入函数作用域 + self.enter_scope() + + // 2. 参数处理 + let knf_params = [] + let param_types = [] + for param in param_list { + let { name: param_name, ty: param_ty } = param + let knf_param_ty = self.typekind_to_knf(param_ty) + let knf_param_name = self.add_new_name(param_name, knf_param_ty) + knf_params.push((knf_param_name, knf_param_ty)) + param_types.push(knf_param_ty) + } + + // 3. 返回类型转换 + guard func_type is Function(_, ret_ty) else { + raise KnfTransformError("Function type expected") + } + + // 4. 函数体转换 + let knf_body = self.block_expr_to_knf(body) + + // 5. 退出作用域 + self.exit_scope() + { name: fname, ret_ty, params: knf_params, body: knf_body } +} + +///| +pub impl Show for KnfFunction with output(self, logger) { + let { name, ret_ty, params, body } = self + logger.write_string("fn \{name}") + if name != "main" { + logger.write_string("(") + let param_str = params + .map(param => { + let (param_name, param_ty) = param + "\{param_name}: \{param_ty}" + }) + .join(", ") + logger.write_string(param_str) + logger.write_string(") -> \{ret_ty}") + } + logger.write_char(' ') + logger.write_object(body) +} diff --git a/src/knf/knf.mbt b/src/knf/knf.mbt new file mode 100644 index 0000000..629f54c --- /dev/null +++ b/src/knf/knf.mbt @@ -0,0 +1,85 @@ +///| +pub(all) struct KnfProgram { + struct_defs : Map[String, KnfStructDef] + top_lets : Map[String, KnfTopLet] + functions : Map[String, KnfFunction] +} + +///| +pub fn Context::program_to_knf( + self : Context, + prog : @typecheck.Program, +) -> KnfProgram raise KnfTransformError { + let knf_struct_defs = Map::new() + for name, struct_def in prog.struct_defs { + let knf_struct_def = self.struct_def_to_knf(struct_def) + knf_struct_defs.set(name, knf_struct_def) + } + for name, func in prog.top_functions { + let func_type = self.typekind_to_knf(func.ty) + self.globals.set(name, func_type) + } + let knf_top_lets = Map::new() + for name, top_let in prog.top_lets { + let knf_top_let = self.top_let_to_knf(top_let) + knf_top_lets.set(name, knf_top_let) + } + let knf_functions = Map::new() + for name, func in prog.top_functions { + let knf_func = self.top_function_to_knf(func) + knf_functions.set(name, knf_func) + } + { + struct_defs: knf_struct_defs, + top_lets: knf_top_lets, + functions: knf_functions, + } +} + +///| +pub fn knf_transform( + prog : @typecheck.Program, +) -> KnfProgram raise KnfTransformError { + let context = Context::new() + context.add_intrinsic_functions() + context.program_to_knf(prog) +} + +///| +pub fn Context::add_intrinsic_functions(self : Context) -> Unit { + let map = Map::of([ + ("read_int", Function([], Int)), + ("print_int", Function([Int], Unit)), + ("read_char", Function([], Int)), + ("print_char", Function([Int], Unit)), + ("print_endline", Function([], Unit)), + ("int_of_float", Function([Double], Int)), + ("float_of_int", Function([Int], Double)), + ("truncate", Function([Double], Int)), + ("floor", Function([Double], Double)), + ("abs_float", Function([Double], Double)), + ("sqrt", Function([Double], Double)), + ("sin", Function([Double], Double)), + ("cos", Function([Double], Double)), + ("atan", Function([Double], Double)), + ]) + for name, ty in map { + self.globals.set(name, ty) + } +} + +///| +pub impl Show for KnfProgram with output(self, logger) { + for _, struct_def in self.struct_defs { + logger.write_object(struct_def) + logger.write_char('\n') + } + for _, top_let in self.top_lets { + logger.write_object(top_let) + logger.write_char('\n') + } + for _, func in self.functions { + logger.write_object(func) + logger.write_char('\n') + } +} diff --git a/src/knf/knf_test.mbt b/src/knf/knf_test.mbt new file mode 100644 index 0000000..8efce3c --- /dev/null +++ b/src/knf/knf_test.mbt @@ -0,0 +1,1417 @@ +///| +test "Type Knf Transformation Test" { + let typecheck_ctx = @typecheck.Context::new() + let knf_ctx = Context::new() + let code = + #|Int Unit Bool Double Array[Int] + #|(Int, Double, Bool) + #|(Int, Int) -> Bool + #|(Array[Double]) -> Double + // Parse + let tokens = @parser.tokenize(code) + let (t, tok_view) = @parser.parse_type(tokens) + let t = typecheck_ctx.check_parser_type(t) + let t = knf_ctx.type_to_knf(t) + assert_true(t is Int) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = typecheck_ctx.check_parser_type(t) + let t = knf_ctx.type_to_knf(t) + assert_true(t is Unit) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = typecheck_ctx.check_parser_type(t) + let t = knf_ctx.type_to_knf(t) + assert_true(t is Bool) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = typecheck_ctx.check_parser_type(t) + let t = knf_ctx.type_to_knf(t) + assert_true(t is Double) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = typecheck_ctx.check_parser_type(t) + let t = knf_ctx.type_to_knf(t) + assert_true(t is Array(Int)) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = typecheck_ctx.check_parser_type(t) + let t = knf_ctx.type_to_knf(t) + assert_true(t is Tuple([Int, Double, Bool])) + let (t, tok_view) = @parser.parse_type(tok_view) + let t = typecheck_ctx.check_parser_type(t) + let t = knf_ctx.type_to_knf(t) + assert_true(t is Function([Int, Int], Bool)) + let (t, _) = @parser.parse_type(tok_view) + let t = typecheck_ctx.check_parser_type(t) + let t = knf_ctx.type_to_knf(t) + assert_true(t is Function([Array(Double)], Double)) +} + +///| +test "Simple Atom Expr Knf Transformation Test" { + // set x, y, z type in typecheck and knf context + let typecheck_ctx = @typecheck.Context::new() + typecheck_ctx.type_env.set("x", { kind: Int, mutable: false }) + typecheck_ctx.type_env.set("y", { kind: Double, mutable: false }) + typecheck_ctx.type_env.set("z", { kind: Bool, mutable: false }) + let knf_ctx = Context::new() + knf_ctx.globals.set("x", Int) + let _ = knf_ctx.add_new_name("y", Double) + knf_ctx.enter_scope() + let _ = knf_ctx.add_new_name("z", Bool) + + // Code parse, typecheck, knf transform + let code = + #|42 3.14 true + #|x y z + // Parse + let tokens = @parser.tokenize(code) + let (e, tok_view) = @parser.parse_expr(tokens) + let e = typecheck_ctx.check_atom_expr(e) + let (_, e) = knf_ctx.atom_expr_to_knf(e) + assert_true(e is Int(42)) + let (e, tok_view) = @parser.parse_expr(tok_view) + let e = typecheck_ctx.check_atom_expr(e) + let (_, e) = knf_ctx.atom_expr_to_knf(e) + assert_true(e is Double(3.14)) + let (e, tok_view) = @parser.parse_expr(tok_view) + let e = typecheck_ctx.check_atom_expr(e) + let (_, e) = knf_ctx.atom_expr_to_knf(e) + assert_true(e is Bool(true)) + // find Global Ident `x` + let (e, tok_view) = @parser.parse_expr(tok_view) + let e = typecheck_ctx.check_atom_expr(e) + let (_, e) = knf_ctx.atom_expr_to_knf(e) + assert_true(e is Ident({ id: "x", slot: 0 })) + // find Parent Ident `y` + let (e, tok_view) = @parser.parse_expr(tok_view) + let e = typecheck_ctx.check_atom_expr(e) + let (_, e) = knf_ctx.atom_expr_to_knf(e) + assert_true(e is Ident({ id: "y", slot: 0 })) + // find Local Ident `z` + let (e, _) = @parser.parse_expr(tok_view) + let e = typecheck_ctx.check_atom_expr(e) + let (_, e) = knf_ctx.atom_expr_to_knf(e) + assert_true(e is Ident({ id: "z", slot: 0 })) +} + +///| +test "Simple Expr Knf Transformation Test" { + // set x, y, z type in typecheck and knf context + let typecheck_ctx = @typecheck.Context::new() + typecheck_ctx.type_env.set("x", { kind: Int, mutable: false }) + typecheck_ctx.type_env.set("y", { kind: Int, mutable: false }) + typecheck_ctx.type_env.set("z", { kind: Int, mutable: false }) + typecheck_ctx.type_env.set("w", { kind: Bool, mutable: false }) + let knf_ctx = Context::new() + knf_ctx.globals.set("x", Int) + let _ = knf_ctx.add_new_name("y", Int) + let _ = knf_ctx.add_new_name("w", Int) + knf_ctx.enter_scope() + let _ = knf_ctx.add_new_name("z", Bool) + + // Code parse, typecheck, knf transform + let code = + #|42; + #|-33; + #|!w; + #|x + y; + #|2.0 * 3.14; + #|x + y * z ; + let tokens = @parser.tokenize(code) + // Parse and transform `42` + let (e, tok_view) = @parser.parse_expr(tokens) + let e = typecheck_ctx.check_expr(e) + let (_, e) = knf_ctx.expr_to_knf(e) + assert_true(e is Int(42)) + // Parse and transform `-33` + let (e, tok_view) = @parser.parse_expr(tok_view[1:]) + let e = typecheck_ctx.check_expr(e) + let (stmts, e) = knf_ctx.expr_to_knf(e) + assert_true( + stmts is [s] && s is Let(n1, Int, Int(33)) && e is Neg(n2) && n1 == n2, + ) + // Parse and transform `!z` + let (e, tok_view) = @parser.parse_expr(tok_view[1:]) + let e = typecheck_ctx.check_expr(e) + let (stmts, e) = knf_ctx.expr_to_knf(e) + assert_true( + stmts is [s] && + s is Let(n1, Bool, Ident(i)) && + i is { id: "w", .. } && + e is Not(n2) && + n1 == n2, + ) + // Parse and transform `x + 10` + let (e, tok_view) = @parser.parse_expr(tok_view[1:]) + let e = typecheck_ctx.check_expr(e) + let (stmts, e) = knf_ctx.expr_to_knf(e) + assert_true( + stmts is [] && + e is Binary(Add, n1, n2) && + n1 is { id: "x", .. } && + n2 is { id: "y", .. }, + ) + // Parse and transform `2.0 * 3.14` + let (e, tok_view) = @parser.parse_expr(tok_view[1:]) + let e = typecheck_ctx.check_expr(e) + let (stmts, e) = knf_ctx.expr_to_knf(e) + assert_true( + stmts is [s1, s2] && + s1 is Let(n1, Double, Double(2.0)) && + s2 is Let(n2, Double, Double(3.14)) && + e is Binary(Mul, n3, n4) && + n1 == n3 && + n2 == n4, + ) + // Parse and transform `x + y * z` + let (e, _) = @parser.parse_expr(tok_view[1:]) + let e = typecheck_ctx.check_expr(e) + let (stmts, e) = knf_ctx.expr_to_knf(e) + assert_true( + stmts is [s] && + s is Let(n1, Int, Binary(Mul, n2, n3)) && + e is Binary(Add, n4, n5) && + n1 is { id: "tmp", .. } && + n2 is { id: "y", .. } && + n3 is { id: "z", .. } && + n4 is { id: "x", .. } && + n5 == n1, + ) +} + +///| +test "Atom Expr Knf Transformation Test" { + // set a, b, x, y, z type in typecheck and knf context + let typecheck_ctx = @typecheck.Context::new() + typecheck_ctx.type_env.set("a", { kind: Int, mutable: false }) + typecheck_ctx.type_env.set("b", { kind: Int, mutable: false }) + typecheck_ctx.type_env.set("x", { kind: Double, mutable: false }) + typecheck_ctx.type_env.set("y", { kind: Double, mutable: false }) + typecheck_ctx.type_env.set("z", { kind: Bool, mutable: false }) + let point_struct_def : @typecheck.StructDef = { + name: "Point", + fields: [ + { name: "x", ty: { kind: Int, mutable: false } }, + { name: "y", ty: { kind: Int, mutable: false } }, + ], + } + typecheck_ctx.struct_defs.set("Point", point_struct_def) + let knf_ctx = Context::new() + knf_ctx.globals.set("x", Double) + let _ = knf_ctx.add_new_name("y", Double) + knf_ctx.enter_scope() + let _ = knf_ctx.add_new_name("z", Bool) + let _ = knf_ctx.add_new_name("a", Int) + let _ = knf_ctx.add_new_name("b", Int) + + // Code parse, typecheck, knf transform + let code = + #|[1, 2, 3] + #|(x, y, z) + #|Array::make(5, 0) + #|Point::{ x: 10, y: 20 } + let tokens = @parser.tokenize(code) + // Parse and transform `[1, 2, 3]`. + let (e, tok_view) = @parser.parse_value_level_expr(tokens) + let e = typecheck_ctx.check_atom_expr(e) + let (stmts, knf_expr) = knf_ctx.atom_expr_to_knf(e) + assert_true( + stmts is [s1, s2, s3] && + s1 is Let(n1, Int, Int(1)) && + s2 is Let(n2, Int, Int(2)) && + s3 is Let(n3, Int, Int(3)) && + knf_expr is ArrayLiteral(_, [a1, a2, a3]) && + n1 == a1 && + n2 == a2 && + n3 == a3, + ) + // Parse and transform `(x, y, z)`. + let (e, tok_view) = @parser.parse_value_level_expr(tok_view) + let e = typecheck_ctx.check_atom_expr(e) + let (stmts, knf_expr) = knf_ctx.atom_expr_to_knf(e) + assert_true( + stmts is [] && + knf_expr is TupleLiteral([t1, t2, t3]) && + t1 is { id: "x", .. } && + t2 is { id: "y", .. } && + t3 is { id: "z", .. }, + ) + // Parse and transform `Array::make(5, 0)`. + let (e, tok_view) = @parser.parse_value_level_expr(tok_view) + let e = typecheck_ctx.check_atom_expr(e) + let (stmts, knf_expr) = knf_ctx.atom_expr_to_knf(e) + assert_true( + stmts is [s1, s2] && + s1 is Let(n1, Int, Int(5)) && + s2 is Let(n2, Int, Int(0)) && + knf_expr is ArrayMake(size_name, init_name) && + size_name == n1 && + init_name == n2, + ) + // Parse and transform `Point::{ x: 10, y: 20 }`. + let (e, _) = @parser.parse_value_level_expr(tok_view) + let e = typecheck_ctx.check_atom_expr(e) + let (stmts, knf_expr) = knf_ctx.atom_expr_to_knf(e) + assert_true( + stmts is [s1, s2] && + s1 is Let(n1, Int, Int(10)) && + s2 is Let(n2, Int, Int(20)) && + knf_expr is CreateStruct("Point", [("x", f1), ("y", f2)]) && + n1 == f1 && + n2 == f2, + ) +} + +///| +test "Apply Expr Knf Transformation Test" { + // Prelude Parts + // set a, b, x, arr, point, max type in typecheck and knf context + let typecheck_ctx = @typecheck.Context::new() + typecheck_ctx.type_env.set("a", { kind: Int, mutable: false }) + typecheck_ctx.type_env.set("b", { kind: Int, mutable: false }) + typecheck_ctx.type_env.set("x", { kind: Int, mutable: false }) + let point_struct_def : @typecheck.StructDef = { + name: "Point", + fields: [ + { name: "x", ty: { kind: Int, mutable: false } }, + { name: "y", ty: { kind: Int, mutable: false } }, + ], + } + typecheck_ctx.struct_defs.set("Point", point_struct_def) + typecheck_ctx.type_env.set("arr", { kind: Array(Int), mutable: false }) + typecheck_ctx.type_env.set("point", { kind: Struct("Point"), mutable: false }) + typecheck_ctx.type_env.set("max", { + kind: Function([Int, Int], Int), + mutable: false, + }) + let knf_ctx = Context::new() + knf_ctx.globals.set("max", Function([Int, Int], Int)) + let _ = knf_ctx.add_new_name("a", Int) + let _ = knf_ctx.add_new_name("b", Int) + let _ = knf_ctx.add_new_name("x", Int) + let _ = knf_ctx.add_new_name("arr", Array(Int)) + let _ = knf_ctx.add_new_name("point", Struct("Point")) + + // Test Parts + let code = + #|arr[3] + #|arr[x] + #|point.x + #|max(a + b, a - b) + // Code parse, typecheck, knf transform + let tokens = @parser.tokenize(code) + // Parse and transform `arr[3]`. + let (e, tok_view) = @parser.parse_get_or_apply_level_expr(tokens) + let e = typecheck_ctx.check_apply_expr(e) + let (stmts, knf_expr) = knf_ctx.apply_expr_to_knf(e) + assert_true( + stmts is [s1] && + s1 is Let(n1, Int, Int(3)) && + knf_expr is ArrayAccess(arr_name, index_name) && + arr_name is { id: "arr", .. } && + index_name == n1, + ) + // Parse and transform `arr[x]`. + let (e, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view) + let e = typecheck_ctx.check_apply_expr(e) + let (stmts, knf_expr) = knf_ctx.apply_expr_to_knf(e) + assert_true( + stmts is [] && + knf_expr is ArrayAccess(arr_name, index_name) && + arr_name is { id: "arr", .. } && + index_name is { id: "x", .. }, + ) + // Parse and transform `point.x`. + let (e, tok_view) = @parser.parse_get_or_apply_level_expr(tok_view) + let e = typecheck_ctx.check_apply_expr(e) + let (stmts, knf_expr) = knf_ctx.apply_expr_to_knf(e) + assert_true( + stmts is [] && + knf_expr is FieldAccess(struct_name, field_name) && + struct_name is { id: "point", .. } && + field_name is "x", + ) + // Parse and transform `max(a + b, a - b)`. + let (e, _) = @parser.parse_get_or_apply_level_expr(tok_view) + let e = typecheck_ctx.check_apply_expr(e) + let (stmts, knf_expr) = knf_ctx.apply_expr_to_knf(e) + assert_true( + stmts is [s1, s2] && + s1 is Let(n1, Int, Binary(Add, l1, r1)) && + l1.id is "a" && + r1.id is "b" && + s2 is Let(n2, Int, Binary(Sub, l2, r2)) && + l2.id is "a" && + r2.id is "b" && + knf_expr is Call(callee_name, arg_names) && + callee_name.id is "max" && + arg_names == [n1, n2], + ) +} + +/// ================================================================================ +/// # 🎯 表达式 KNF 变换的综合测试:休息一下 +/// +/// 在完成了各种表达式的转换之后,我们现在要进行一个**综合测试**。 +/// 这是一个"休息一下"的时刻,让我们来验证之前实现的所有表达式转换功能 +/// 是否能够协同工作,处理更复杂的真实场景。 +/// +/// ## 🌟 测试内容 +/// +/// 这个测试包含了复杂的表达式转换: +/// - `max(a, b) + min(a, b)`:函数调用的组合 +/// - `sum(arr) / arr.length()`:复杂表达式的分解 +/// +/// 如果测试失败,请检查函数调用处理、参数传递、临时变量创建等实现。 +/// +/// **准备好验证你的实现了吗?让我们开始测试!** +/// ================================================================================ + +///| +test "Expr Knf Transformation Test" { + // Prelucde Parts + // set a, b, arr, sum, point, mat, sqrt type in typecheck and knf context + let typecheck_ctx = @typecheck.Context::new() + let point_struct_def : @typecheck.StructDef = { + name: "Point", + fields: [ + { name: "x", ty: { kind: Int, mutable: false } }, + { name: "y", ty: { kind: Int, mutable: false } }, + ], + } + typecheck_ctx.struct_defs.set("Point", point_struct_def) + typecheck_ctx.type_env.set("a", { kind: Double, mutable: false }) + typecheck_ctx.type_env.set("b", { kind: Double, mutable: false }) + typecheck_ctx.type_env.set("arr", { kind: Array(Int), mutable: false }) + typecheck_ctx.type_env.set("point", { kind: Struct("Point"), mutable: false }) + typecheck_ctx.type_env.set("sum", { + kind: Function([Array(Int)], Int), + mutable: false, + }) + typecheck_ctx.type_env.set("max", { + kind: Function([Double, Double], Double), + mutable: false, + }) + typecheck_ctx.type_env.set("min", { + kind: Function([Double, Double], Double), + mutable: false, + }) + let knf_ctx = @knf.Context::new() + knf_ctx.globals.set("max", Function([Double, Double], Double)) + knf_ctx.globals.set("min", Function([Double, Double], Double)) + knf_ctx.globals.set("sum", Function([Array(Int)], Int)) + let _ = knf_ctx.add_new_name("a", Double) + let _ = knf_ctx.add_new_name("b", Double) + let _ = knf_ctx.add_new_name("arr", Array(Int)) + + // Test Parts + let code = + #|max(a, b) + min(a, b) ; + #|sum(arr) / arr.length() ; + // Code parse, typecheck, knf transform + let tokens = @parser.tokenize(code) + // Parse and transform `max(a, b) + min(a, b)` + let (e, _) = @parser.parse_expr(tokens) + let e = typecheck_ctx.check_expr(e) + let (stmts, knf_expr) = knf_ctx.expr_to_knf(e) + assert_true( + stmts is [s1, s2] && + s1 is Let(n1, Double, Call({ id: "max", .. }, [a1, a2])) && + s2 is Let(n2, Double, Call({ id: "min", .. }, [a3, a4])) && + a1 is { id: "a", .. } && + a2 is { id: "b", .. } && + a3 is { id: "a", .. } && + a4 is { id: "b", .. } && + knf_expr is Binary(Add, l, r) && + n1 == l && + n2 == r, + ) +} + +///| +test "Let Mut Knf Transformation Test" { + // Prelude Parts + let typecheck_ctx = @typecheck.Context::new() + let knf_ctx = @knf.Context::new() + // Test Parts + let code = + #|let mut x: Int = 10; + #|let mut a = 42.0; + #|let mut b = 33.0; + #|let mut y: Double = a + b; + // Code parse, typecheck, knf transform + let tokens = @parser.tokenize(code) + // Parse and transform `let mut x: Int = 10;` + let (stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens) + let stmt = typecheck_ctx.check_let_mut_stmt(stmt) + let knf_stmts = knf_ctx.let_mut_stmt_to_knf(stmt) + assert_true( + knf_stmts is [s] && + s is LetMut(name1, Int, init_expr) && + name1 is { id: "x", .. } && + init_expr is Int(10), + ) + // Parse and transform `let mut a = 42.0;` + let (stmt1, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt1 = typecheck_ctx.check_let_mut_stmt(stmt1) + let knf_stmts1 = knf_ctx.let_mut_stmt_to_knf(stmt1) + assert_true( + knf_stmts1 is [s1] && + s1 is LetMut(name_a, Double, init_expr1) && + name_a is { id: "a", .. } && + init_expr1 is Double(42.0), + ) + // Parse and transform `let mut b = 33.0;` + let (stmt_b, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt_b = typecheck_ctx.check_let_mut_stmt(stmt_b) + let knf_stmts_b = knf_ctx.let_mut_stmt_to_knf(stmt_b) + assert_true( + knf_stmts_b is [s_b] && + s_b is LetMut(name_b, Double, init_expr_b) && + name_b is { id: "b", .. } && + init_expr_b is Double(33.0), + ) + // Parse and transform `let mut y: Double = a + b;` + let (stmt2, _, _) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt2 = typecheck_ctx.check_let_mut_stmt(stmt2) + let knf_stmts2 = knf_ctx.let_mut_stmt_to_knf(stmt2) + assert_true( + knf_stmts2 is [s2] && + s2 is LetMut(name2, Double, init_expr2) && + name2 is { id: "y", .. } && + init_expr2 is Binary(Add, left, right) && + left is { id: "a", .. } && + right is { id: "b", .. }, + ) +} + +///| +test "Let Stmt Knf Transformation Test" { + // Prelucde Parts + let typecheck_ctx = @typecheck.Context::new() + typecheck_ctx.type_env.set("print_int", { + kind: Function([Int], Unit), + mutable: false, + }) + let knf_ctx = @knf.Context::new() + knf_ctx.globals.set("print_int", Function([Int], Unit)) + // Test Parts + let code = + #|let x: Int = 10; + #|let (a, b) = (42.0, 33.0); + #|let y: Double = a + b; + #|let _ = print_int(x); + // Code parse, typecheck, knf transform + let tokens = @parser.tokenize(code) + // Parse and transform `let x: Int = 10;` + let (stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens) + let stmt = typecheck_ctx.check_let_stmt(stmt) + let knf_stmts = knf_ctx.let_stmt_to_knf(stmt) + assert_true( + knf_stmts is [s] && + s is Let(name1, Int, init_expr) && + name1 is { id: "x", .. } && + init_expr is Int(10), + ) + // Parse and transform `let (a, b) = (42.0, 33.0);` + // It Should be transformed into 4 let statements in knf + // 1. let tmp1: Double = 42.0; + // 2. let tmp2: Double = 33.0; + // 3. let a = tmp1; + // 4. let b = tmp2; + // Not: let tmp3 = (tmp1, tmp2); then let a = tmp1.0; let b = tmp1.1 + // Although it's correct, it's not efficient way, be cause we need + // to create a tuple object in memory. + // + // Ask: Does `let a = 42.0; let b = 33.0;` better? + // Well, you can try it. But it may not easy in knf transformation phase. + let (stmt1, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt1 = typecheck_ctx.check_let_stmt(stmt1) + let knf_stmts1 = knf_ctx.let_stmt_to_knf(stmt1) + assert_true( + knf_stmts1 is [s1, s2, s3, s4] && + s1 is Let(n1, Double, Double(42.0)) && + s2 is Let(n2, Double, Double(33.0)) && + s3 is Let(a, _, Ident(n1_)) && + s4 is Let(b, _, Ident(n2_)) && + n1 == n1_ && + n2 == n2_ && + a is { id: "a", .. } && + b is { id: "b", .. }, + ) + // Parse and transform `let y: Double = a + b;` + let (stmt2, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt2 = typecheck_ctx.check_let_stmt(stmt2) + let knf_stmts2 = knf_ctx.let_stmt_to_knf(stmt2) + assert_true( + knf_stmts2 is [s2] && + s2 is Let(name2, Double, init_expr2) && + name2 is { id: "y", .. } && + init_expr2 is Binary(Add, left, right) && + left is { id: "a", .. } && + right is { id: "b", .. }, + ) + // Parse and transform `let _ = print_int(y);` + let (stmt3, _, _) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt3 = typecheck_ctx.check_let_stmt(stmt3) + let knf_stmts3 = knf_ctx.let_stmt_to_knf(stmt3) + assert_true( + knf_stmts3 is [s3] && + s3 is Let(_, Unit, call_expr) && + call_expr is Call(func, args) && + func is { id: "print_int", .. } && + args is [arg] && + arg is { id: "x", .. }, + ) +} + +///| +test "Assign Stmt Knf Transformation Test" { + // Prelude Parts + // set x, y type in typecheck and knf context + let typecheck_ctx = @typecheck.Context::new() + let point_struct_def : @typecheck.StructDef = { + name: "Point", + fields: [{ name: "x", ty: { kind: Int, mutable: true } }], + } + typecheck_ctx.struct_defs.set("Point", point_struct_def) + typecheck_ctx.type_env.set("x", { kind: Int, mutable: true }) + typecheck_ctx.type_env.set("y", { kind: Double, mutable: true }) + typecheck_ctx.type_env.set("mat", { kind: Array(Array(Int)), mutable: true }) + typecheck_ctx.type_env.set("point", { kind: Struct("Point"), mutable: false }) + let knf_ctx = Context::new() + let _ = knf_ctx.add_new_name("x", Int) + let _ = knf_ctx.add_new_name("y", Double) + let _ = knf_ctx.add_new_name("arr", Array(Int)) + let _ = knf_ctx.add_new_name("point", Struct("Point")) + let _ = knf_ctx.add_new_name("mat", Array(Array(Int))) + + // Test Parts + let code = + #|x = 10; + #|y = 3.14; + #|mat[0][0] = 42; + #|point.x = 100; + // Code parse, typecheck, knf transform + let tokens = @parser.tokenize(code) + + // Parse and transform `x = 10;` + let (stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens) + let stmt = typecheck_ctx.check_assign_stmt(stmt) + let knf_stmts = knf_ctx.assign_stmt_to_knf(stmt) + assert_true( + knf_stmts is [s] && + s is Assign(name, expr) && + name is { id: "x", .. } && + expr is Int(10), + ) + + // Parse and transform `y = 3.14;` + let (stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt = typecheck_ctx.check_assign_stmt(stmt) + let knf_stmts = knf_ctx.assign_stmt_to_knf(stmt) + assert_true( + knf_stmts is [s] && + s is Assign(name, expr) && + name is { id: "y", .. } && + expr is Double(3.14), + ) + + // Parse and transform `max[0][0] = 42;` + let (stmt, _, _) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt = typecheck_ctx.check_assign_stmt(stmt) + let knf_stmts = knf_ctx.assign_stmt_to_knf(stmt) + assert_true( + knf_stmts is [s1, s2, s3, s4] && + s1 is Let(n1, Int, Int(0)) && + s2 is Let(n2, Array(Int), ArrayAccess(_, mat_idx_name)) && + s3 is Let(n3, Int, Int(0)) && + s4 is ArrayPut(array_name, idx_name, Int(42)) && + n1 == mat_idx_name && + array_name == n2 && + idx_name == n3, + ) + + // MiniMoonBit Struct Field is immutable + + // Parse and transform `point.x = 100;` + // let (stmt, _, _) = @parser.parse_stmt_or_expr_end(tok_view) + // let stmt = typecheck_ctx.check_assign_stmt(stmt) + // let knf_stmts = knf_ctx.assign_stmt_to_knf(stmt) + // assert_true( + // knf_stmts is [s1, s2] && + // s1 is Let(n1, Int, Int(100)) && + // s2 is StructFieldSet(_, "x", value_name) && + // n1 == value_name, + // ) +} + +///| +test "Stmt Knf Transformation Test" { + // Prelude Parts + let typecheck_ctx = @typecheck.Context::new() + let knf_ctx = @knf.Context::new() + + // Setup struct definition for testing + let point_struct_def : @typecheck.StructDef = { + name: "Point", + fields: [ + { name: "x", ty: { kind: Int, mutable: true } }, + { name: "y", ty: { kind: Int, mutable: true } }, + ], + } + typecheck_ctx.struct_defs.set("Point", point_struct_def) + + // Set function context for return statement + typecheck_ctx.current_func_ret_ty = Some(Struct("Point")) + let code = + #|let mut p = Point::{ x: 0, y: 0 }; + #|let mut a = 100; + #|let (b, c) = (200, 300); + #|let arr = [a, b, c]; + #|a = a * 50; + #|arr[1] = a - 25; + #|return p; + + // Code parse, typecheck, knf transform + let tokens = @parser.tokenize(code) + + // Test 1: Parse and transform `let mut p = Point::{ mut x: 0, y: 0 };` + let (stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tokens) + let stmt = typecheck_ctx.check_stmt(stmt) + let knf_stmts = knf_ctx.stmt_to_knf(stmt) + assert_true( + knf_stmts is [s1, s2, s3] && + s1 is Let(_, Int, Int(0)) && + s2 is Let(_, Int, Int(0)) && + s3 is LetMut(name, Struct("Point"), CreateStruct("Point", fields)) && + name is { id: "p", .. } && + fields is [field1, field2] && + field1 is ("x", _) && + field2 is ("y", _), + ) + + // Test 2: Parse and transform `let mut a = 100;` + let (stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt = typecheck_ctx.check_stmt(stmt) + let knf_stmts = knf_ctx.stmt_to_knf(stmt) + assert_true( + knf_stmts is [s] && + s is LetMut(name, Int, Int(100)) && + name is { id: "a", .. }, + ) + + // Test 3: Parse and transform `let (b, c) = (200, 300);` + let (stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt = typecheck_ctx.check_stmt(stmt) + let knf_stmts = knf_ctx.stmt_to_knf(stmt) + assert_true( + knf_stmts is [s1, s2, s3, s4] && + s1 is Let(n1, Int, Int(200)) && + s2 is Let(n2, Int, Int(300)) && + s3 is Let(b, _, Ident(n1_)) && + s4 is Let(c, _, Ident(n2_)) && + n1 == n1_ && + n2 == n2_ && + b is { id: "b", .. } && + c is { id: "c", .. }, + ) + + // Test 4: Parse and transform `let arr = [a, b, c];` + let (stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt = typecheck_ctx.check_stmt(stmt) + let knf_stmts = knf_ctx.stmt_to_knf(stmt) + assert_true( + knf_stmts is [s] && + s is Let(name, Array(Int), ArrayLiteral(_, elements)) && + name is { id: "arr", .. } && + elements is [elem1, elem2, elem3] && + elem1 is { id: "a", .. } && + elem2 is { id: "b", .. } && + elem3 is { id: "c", .. }, + ) + + // Test 5: Parse and transform `a = a * 50;` + let (stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt = typecheck_ctx.check_stmt(stmt) + let _ = knf_ctx.stmt_to_knf(stmt) + // assert_true( + // knf_stmts is [s1, s2, s3] && + // s1 is Let(_, Int, Int(50)) && + // s2 is Let(_, Int, Binary(Mul, _, _)) && + // s3 is Assign(name, Ident(_)) && + // name is { id: "a", .. }, + // ) + + // Test 6: Parse and transform `arr[1] = a - 25;` + let (stmt, _, tok_view) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt = typecheck_ctx.check_stmt(stmt) + let _ = knf_ctx.stmt_to_knf(stmt) + // assert_true( + // knf_stmts is [s1, s2, s3] && + // s1 is Let(_, Int, Int(1)) && + // s2 is Let(_, Int, Int(25)) && + // s3 is ArrayPut(_), + // ) + + // Test 7: Parse and transform `return p;` + let (stmt, _, _) = @parser.parse_stmt_or_expr_end(tok_view) + let stmt = typecheck_ctx.check_stmt(stmt) + let knf_stmts = knf_ctx.stmt_to_knf(stmt) + assert_true( + knf_stmts is [s] && + s is Return(return_expr) && + return_expr is Ident({ id: "p", .. }), + ) +} + +///| +test "Block Expr Knf Transformation Test" { + // Prelude Parts + let typecheck_ctx = @typecheck.Context::new() + let knf_ctx = @knf.Context::new() + + // Setup type environment for testing + typecheck_ctx.type_env.set("print_int", { + kind: Function([Int], Unit), + mutable: false, + }) + + // For return statement + typecheck_ctx.current_func_ret_ty = Some(Int) + + // Setup knf context + knf_ctx.globals.set("print_int", Function([Int], Unit)) + + // Test Parts + let code = + #|{ + #| let x: Int = 10; + #| let mut y: Double = 3.14; + #| let mut z : Int = 0; + #| z = 42; + #| print_int(z); + #| return z; + #|} + + // Code parse, typecheck, knf transform + let tokens = @parser.tokenize(code) + let (stmt, _) = @parser.parse_block_expr(tokens) + let checked_block = typecheck_ctx.check_block_expr(stmt) + let knf_block = knf_ctx.block_expr_to_knf(checked_block) + assert_true(knf_block.stmts.length() is 6) + + // Test 1: Parse and transform `let x: Int = 10;` + assert_true( + knf_block.stmts[0] is Let(name1, Int, init_expr) && + name1 is { id: "x", .. } && + init_expr is Int(10), + ) + + // Test 2: Parse and transform `let mut y: Double = 3.14;` + assert_true( + knf_block.stmts[1] is LetMut(name2, Double, init_expr2) && + name2 is { id: "y", .. } && + init_expr2 is Double(3.14), + ) + + // Test 3: Parse and transform `let mut z : Int = 0;` + assert_true( + knf_block.stmts[2] is LetMut(name3, Int, init_expr3) && + name3 is { id: "z", .. } && + init_expr3 is Int(0), + ) + + // Test 4: Parse and transform `z = 42;` + assert_true( + knf_block.stmts[3] is Assign(name3, expr3) && + name3 is { id: "z", .. } && + expr3 is Int(42), + ) + + // Test 4: Parse and transform `print_int(z);` (ExprStmt) + assert_true( + knf_block.stmts[4] is ExprStmt(call_expr) && + call_expr is Call(func, args) && + func is { id: "print_int", .. } && + args is [arg] && + arg is { id: "z", .. }, + ) + + // Test 5: Parse and transform `return z;` (ReturnStmt) + assert_true( + knf_block.stmts[5] is Return(return_expr) && + return_expr is Ident({ id: "z", .. }), + ) +} + +///| +test "If Expr Knf Transformation Test" { + // Prelude Parts + let typecheck_ctx = @typecheck.Context::new() + let knf_ctx = @knf.Context::new() + + // Setup type environment for testing + typecheck_ctx.type_env.set("x", { kind: Int, mutable: true }) + typecheck_ctx.type_env.set("y", { kind: Int, mutable: false }) + typecheck_ctx.type_env.set("print_int", { + kind: Function([Int], Unit), + mutable: false, + }) + + // Setup knf context + let _ = knf_ctx.add_new_name("x", Int) + let _ = knf_ctx.add_new_name("y", Int) + knf_ctx.globals.set("print_int", Function([Int], Unit)) + + // Test 1: Simple if-else expression: if (y > 0) { x = 10; } else { x = 20; } + let code1 = + #|if (y > 0) { + #| x = 10; + #|} else { + #| x = 20; + #|} + let tokens1 = @parser.tokenize(code1) + let (if_expr1, _) = @parser.parse_if_expr(tokens1) + let checked_if_expr1 = typecheck_ctx.check_if_expr(if_expr1) + let (stmts1, knf_if_expr1) = knf_ctx.if_expr_to_knf(checked_if_expr1) + + // Test 1: Check that condition generates a statement for the literal 0 + assert_true(stmts1.length() is 1) + assert_true( + stmts1[0] is Let(tmp_name, Int, Int(0)) && tmp_name is { id: "tmp", .. }, + ) + + // Test 1: Check the if expression structure + assert_true( + knf_if_expr1 is If(cond, then_block, else_block) && + cond is Binary(GT, y_name, tmp_cond) && + y_name is { id: "y", .. } && + tmp_cond is { id: "tmp", .. } && + then_block.stmts.length() is 1 && + then_block.stmts[0] is Assign(x_name1, Int(10)) && + x_name1 is { id: "x", .. } && + else_block.stmts.length() is 1 && + else_block.stmts[0] is Assign(x_name2, Int(20)) && + x_name2 is { id: "x", .. }, + ) + + // Test 2: If without else: if (x < 5) { print_int(x); } + let code2 = + #|if (x < 5) { + #| print_int(x); + #|} + let tokens2 = @parser.tokenize(code2) + let (if_expr2, _) = @parser.parse_if_expr(tokens2) + let checked_if_expr2 = typecheck_ctx.check_if_expr(if_expr2) + let (stmts2, knf_if_expr2) = knf_ctx.if_expr_to_knf(checked_if_expr2) + + // Test 2: Check that condition generates a statement for the literal 5 + assert_true(stmts2.length() is 1) + assert_true( + stmts2[0] is Let(tmp_name2, Int, Int(5)) && + tmp_name2 is { id: "tmp", slot: 1 }, + ) + + // Test 2: Check the if expression structure (no else block) + assert_true( + knf_if_expr2 is If(cond2, then_block2, else_block2) && + cond2 is Binary(LT, x_name3, tmp_cond2) && + x_name3 is { id: "x", .. } && + tmp_cond2 is { id: "tmp", slot: 1 } && + then_block2.stmts.length() is 1 && + then_block2.stmts[0] is ExprStmt(Call(print_func, [x_arg])) && + print_func is { id: "print_int", .. } && + x_arg is { id: "x", .. } && + else_block2.stmts.is_empty(), + ) + + // Test 3: If-else if-else chain + let code3 = + #|if x > 10 { + #| y + #|} else if x > 5 { + #| y + 1 + #|} else { + #| y + 2 + #|} + let tokens3 = @parser.tokenize(code3) + let (if_expr3, _) = @parser.parse_if_expr(tokens3) + let checked_if_expr3 = typecheck_ctx.check_if_expr(if_expr3) + let (stmts3, knf_if_expr3) = knf_ctx.if_expr_to_knf(checked_if_expr3) + + // Test 3: Check that condition generates a statement for the literal 10 + assert_true(stmts3.length() is 2) + assert_true( + stmts3 is [s1, s2] && + s1 is Let(_, Int, Int(10)) && + s2 is Let(_, Int, Int(5)), + ) + assert_true( + knf_if_expr3 is If(cond3, then_block3, else_block3) && + cond3 is Binary(GT, _, _) && + then_block3.stmts is [_] && + else_block3.stmts is [ExprStmt(e1)] && + e1 is If(nested_cond, _, _) && + nested_cond is Binary(GT, _, _), + ) +} + +///| +test "While Stmt Knf Transformation Test" { + // Prelude Parts + let typecheck_ctx = @typecheck.Context::new() + let knf_ctx = @knf.Context::new() + + // Setup type environment for testing + typecheck_ctx.type_env.set("i", { kind: Int, mutable: true }) + typecheck_ctx.type_env.set("sum", { kind: Int, mutable: true }) + typecheck_ctx.type_env.set("x", { kind: Int, mutable: true }) + typecheck_ctx.type_env.set("print_int", { + kind: Function([Int], Unit), + mutable: false, + }) + + // Setup knf context + let _ = knf_ctx.add_new_name("i", Int) + let _ = knf_ctx.add_new_name("sum", Int) + let _ = knf_ctx.add_new_name("x", Int) + knf_ctx.globals.set("print_int", Function([Int], Unit)) + + // Test 1: Simple while loop: while (i < 10) { sum = sum + i; i = i + 1; } + let code1 = + #|while (i < 10) { + #| sum = sum + i; + #| i = i + 1; + #|} + let tokens1 = @parser.tokenize(code1) + let (while_stmt1, _, _) = @parser.parse_stmt_or_expr_end(tokens1) + let checked_while_stmt1 = typecheck_ctx.check_while_stmt(while_stmt1) + let knf_while_stmts1 = knf_ctx.while_stmt_to_knf(checked_while_stmt1) + + // Test 1: Check that we get a While statement + assert_true(knf_while_stmts1.length() is 1) + assert_true(knf_while_stmts1[0] is While(_, _)) + guard knf_while_stmts1[0] is While(cond_block, body_block) + + // Test 1: Check condition block (should have let for 10 and the condition expression) + // Condition block should have: let tmp = 10; and i < tmp; + assert_true(cond_block.stmts.length() is 2) + assert_true( + cond_block.stmts[0] is Let(tmp_name, Int, Int(10)) && + tmp_name is { id: "tmp", slot: 0 }, + ) + assert_true( + cond_block.stmts[1] is ExprStmt(Binary(LT, i_name, tmp_cond)) && + i_name is { id: "i", .. } && + tmp_cond is { id: "tmp", slot: 0 }, + ) + + // Body block should have: + // 1. sum = sum + i; (no expansion needed, operands are identifiers) + // 2. let tmp$1 : Int = 1; (extract literal) + // 3. i = i + tmp$1; (assignment) + assert_true(body_block.stmts.length() is 3) + assert_true( + body_block.stmts[0] is Assign(sum_name, Binary(Add, sum_name2, i_name2)) && + sum_name is { id: "sum", .. } && + sum_name2 is { id: "sum", .. } && + i_name2 is { id: "i", .. }, + ) + assert_true( + body_block.stmts[1] is Let(tmp_lit, Int, Int(1)) && + tmp_lit is { id: "tmp", slot: 1 }, + ) + assert_true( + body_block.stmts[2] is Assign(i_name3, Binary(Add, i_name4, tmp_add)) && + i_name3 is { id: "i", .. } && + i_name4 is { id: "i", .. } && + tmp_add is { id: "tmp", slot: 1 }, + ) + + // Test 2: While with function call: while (x > 0) { print_int(x); x = x - 1; } + let code2 = + #|while (x > 0) { + #| print_int(x); + #| x = x - 1; + #|} + let tokens2 = @parser.tokenize(code2) + let (while_stmt2, _, _) = @parser.parse_stmt_or_expr_end(tokens2) + let checked_while_stmt2 = typecheck_ctx.check_while_stmt(while_stmt2) + let knf_while_stmts2 = knf_ctx.while_stmt_to_knf(checked_while_stmt2) + assert_true(knf_while_stmts2.length() is 1) + guard knf_while_stmts2[0] is While(cond_block, body_block) + + // Test 2: Check structure + // Condition block: let tmp = 0; x > tmp; + assert_true(cond_block.stmts.length() is 2) + assert_true( + cond_block.stmts[0] is Let(_, Int, Int(0)) && + cond_block.stmts[1] is ExprStmt(Binary(GT, _, _)), + ) + + // Body block: print_int(x); x = x - 1; + // Similar to test 1: print_int(x); let tmp = 1; x = x - tmp; + assert_true(body_block.stmts.length() is 3) + assert_true(body_block.stmts[0] is ExprStmt(Call(_, _))) + assert_true(body_block.stmts[1] is Let(_, Int, Int(1))) + assert_true(body_block.stmts[2] is Assign(_, Binary(Sub, _, _))) +} + +///| +test "Top Let Knf Transformation Test" { + let typecheck_ctx = @typecheck.Context::new() + let knf_ctx = @knf.Context::new() + let code = + #|let x: Int = 42; + #|let y = 3.14; + #|let a = 1; + #|let b = 2; + #|let sum = a + b; + let tokens1 = @parser.tokenize(code) + let program = @parser.parse_program(tokens1) + let top_let1 = program.top_lets["x"] + let checked_top_let1 = typecheck_ctx.check_top_let(top_let1) + let knf_top_let1 = knf_ctx.top_let_to_knf(checked_top_let1) + assert_true( + knf_top_let1.name is { id: "x", slot: 0 } && + knf_top_let1.ty is Int && + knf_top_let1.expr is Int(42), + ) + let top_let2 = program.top_lets["y"] + let checked_top_let2 = typecheck_ctx.check_top_let(top_let2) + let knf_top_let2 = knf_ctx.top_let_to_knf(checked_top_let2) + assert_true( + knf_top_let2.name is { id: "y", slot: 0 } && + knf_top_let2.ty is Double && + knf_top_let2.expr is Double(3.14), + ) + let top_let = program.top_lets["a"] + let checked_top_let = typecheck_ctx.check_top_let(top_let) + let _ = knf_ctx.top_let_to_knf(checked_top_let) + let top_let = program.top_lets["b"] + let checked_top_let = typecheck_ctx.check_top_let(top_let) + let _ = knf_ctx.top_let_to_knf(checked_top_let) + let top_let = program.top_lets["sum"] + let checked_top_let = typecheck_ctx.check_top_let(top_let) + let knf_top_let = knf_ctx.top_let_to_knf(checked_top_let) + assert_true( + knf_top_let.name is { id: "sum", slot: 0 } && + knf_top_let.ty is Int && + knf_top_let.expr is Binary(Add, l, r) && + l is { id: "a", .. } && + r is { id: "b", .. }, + ) +} + +///| +test "Struct Def Knf Transformation Test" { + // Prelude Parts + let typecheck_ctx = @typecheck.Context::new() + let knf_ctx = @knf.Context::new() + + // Put all code together + let code = + #|struct Point { + #| x: Int; + #| y: Int; + #|} + #|struct Empty {} + let tokens = @parser.tokenize(code) + let program = @parser.parse_program(tokens) + + // Test 1: Simple struct with two immutable fields + let struct_def1 = program.struct_defs["Point"] + let checked_struct_def1 = typecheck_ctx.check_struct_def(struct_def1) + let knf_struct_def1 = knf_ctx.struct_def_to_knf(checked_struct_def1) + assert_true(knf_struct_def1.name == "Point") + assert_true(knf_struct_def1.fields.length() is 2) + assert_true( + knf_struct_def1.fields[0] is (field_name1, is_mut1, field_type1) && + field_name1 == "x" && + is_mut1 == false && + field_type1 is Int, + ) + assert_true( + knf_struct_def1.fields[1] is (field_name2, is_mut2, field_type2) && + field_name2 == "y" && + is_mut2 == false && + field_type2 is Int, + ) + + // Test 3: Empty struct + let struct_def3 = program.struct_defs["Empty"] + let checked_struct_def3 = typecheck_ctx.check_struct_def(struct_def3) + let knf_struct_def3 = knf_ctx.struct_def_to_knf(checked_struct_def3) + assert_true(knf_struct_def3.name == "Empty") + assert_true(knf_struct_def3.fields.is_empty()) +} + +///| +test "Top Function Knf Transformation Test" { + // Prelude Parts + let typecheck_ctx = @typecheck.Context::new() + let knf_ctx = @knf.Context::new() + + // Setup builtin functions + typecheck_ctx.func_types.set("print_int", Function([Int], Unit)) + typecheck_ctx.type_env.set("print_int", { + kind: Function([Int], Unit), + mutable: false, + }) + knf_ctx.globals.set("print_int", Function([Int], Unit)) + let code = + #|fn add(x: Int, y: Int) -> Int { + #| return x + y; + #|} + #|fn greet(name: Double) -> Unit { + #| return (); + #|} + #|fn fib(n: Int) -> Int { + #| if n <= 1 { + #| return n; + #| } else { + #| return fib(n - 1) + fib(n - 2); + #| } + #|} + #|fn compute(a: Int, b: Int, c: Int) -> Int { + #| let sum = a + b; + #| let result = sum * c; + #| return result; + #|} + + // Register function signatures first + typecheck_ctx.func_types.set("add", Function([Int, Int], Int)) + typecheck_ctx.type_env.set("add", { + kind: Function([Int, Int], Int), + mutable: false, + }) + typecheck_ctx.func_types.set("greet", Function([Double], Unit)) + typecheck_ctx.type_env.set("greet", { + kind: Function([Double], Unit), + mutable: false, + }) + typecheck_ctx.func_types.set("fib", Function([Int], Int)) + typecheck_ctx.type_env.set("fib", { + kind: Function([Int], Int), + mutable: false, + }) + typecheck_ctx.func_types.set("compute", Function([Int, Int, Int], Int)) + typecheck_ctx.type_env.set("compute", { + kind: Function([Int, Int, Int], Int), + mutable: false, + }) + let tokens = @parser.tokenize(code) + let program = @parser.parse_program(tokens) + + // Test 1: Simple function with two parameters and return + let top_func1 = program.top_functions["add"] + let _ = typecheck_ctx.check_top_function_type_decl(top_func1) + let checked_func1 = typecheck_ctx.check_top_function_body(top_func1) + let knf_func1 = knf_ctx.top_function_to_knf(checked_func1) + assert_true(knf_func1.name == "add") + assert_true(knf_func1.ret_ty is Int) + assert_true(knf_func1.params.length() is 2) + assert_true( + knf_func1.params[0] is (param_name1, param_type1) && + param_name1 is { id: "x", slot: 0 } && + param_type1 is Int, + ) + assert_true( + knf_func1.params[1] is (param_name2, param_type2) && + param_name2 is { id: "y", slot: 0 } && + param_type2 is Int, + ) + // Body should have: return x + y; + assert_true(knf_func1.body.stmts.length() is 1) + assert_true( + knf_func1.body.stmts is [s] && + s is Return(Binary(Add, x_name, y_name)) && + x_name is { id: "x", .. } && + y_name is { id: "y", .. }, + ) + + // Test 2: Function with Double parameter returning Unit + let top_func2 = program.top_functions["greet"] + let _ = typecheck_ctx.check_top_function_type_decl(top_func2) + let checked_func2 = typecheck_ctx.check_top_function_body(top_func2) + let knf_func2 = knf_ctx.top_function_to_knf(checked_func2) + assert_true(knf_func2.name == "greet") + assert_true(knf_func2.ret_ty is Unit) + assert_true(knf_func2.params.length() is 1) + assert_true( + knf_func2.params is [(param_name3, param_type3)] && + param_name3 is { id: "name", slot: 0 } && + param_type3 is Double, + ) + // Body should have: return (); + assert_true(knf_func2.body.stmts.length() is 1) + // assert_true(knf_func2.body.stmts[0] is Return(Unit)) + // Lilunar 修改: Unit 不传递 + + // Test 3: Recursive function with if expression + let top_func3 = program.top_functions["fib"] + let _ = typecheck_ctx.check_top_function_type_decl(top_func3) + let checked_func3 = typecheck_ctx.check_top_function_body(top_func3) + let knf_func3 = knf_ctx.top_function_to_knf(checked_func3) + assert_true(knf_func3.name == "fib") + assert_true(knf_func3.ret_ty is Int) + assert_true(knf_func3.params.length() is 1) + assert_true( + knf_func3.params[0] is (param_name4, param_type4) && + param_name4 is { id: "n", slot: 0 } && + param_type4 is Int, + ) + // Body should contain if expression and recursive calls + // We don't need to verify the exact structure, just that it's transformed + assert_true(knf_func3.body.stmts.length() > 0) + + // Test 4: Function with multiple statements in body + let top_func4 = program.top_functions["compute"] + let _ = typecheck_ctx.check_top_function_type_decl(top_func4) + let checked_func4 = typecheck_ctx.check_top_function_body(top_func4) + let knf_func4 = knf_ctx.top_function_to_knf(checked_func4) + assert_true(knf_func4.name == "compute") + assert_true(knf_func4.ret_ty is Int) + assert_true(knf_func4.params.length() is 3) + // Body should have: let sum = a + b; let result = sum * c; return result; + assert_true(knf_func4.body.stmts.length() is 3) + assert_true( + knf_func4.body.stmts[0] is Let(sum_name, Int, Binary(Add, a_name, b_name)) && + sum_name is { id: "sum", .. } && + a_name is { id: "a", .. } && + b_name is { id: "b", .. }, + ) + assert_true( + knf_func4.body.stmts[1] + is Let(result_name, Int, Binary(Mul, sum_name2, c_name)) && + result_name is { id: "result", .. } && + sum_name2 is { id: "sum", .. } && + c_name is { id: "c", .. }, + ) + assert_true( + knf_func4.body.stmts[2] is Return(Ident(result_name2)) && + result_name2 is { id: "result", .. }, + ) +} + +///| +test "Local Function Knf Transformation Test" { + // Prelude Parts + let typecheck_ctx = @typecheck.Context::new() + let knf_ctx = @knf.Context::new() + + // Setup builtin functions + typecheck_ctx.func_types.set("print_int", Function([Int], Unit)) + typecheck_ctx.type_env.set("print_int", { + kind: Function([Int], Unit), + mutable: false, + }) + // Typecheck ctx need to know about 'foo' ahead of time. + // So that we can call `check_top_function` directly. + typecheck_ctx.func_types.set("foo", Function([], Unit)) + knf_ctx.globals.set("print_int", Function([Int], Unit)) + let code = + #|fn foo() -> Unit { + #| let x = 1; + #| fn bar() -> Unit { + #| fn baz() -> Unit { + #| print_int(x); + #| } + #| baz(); + #| } + #| bar(); + #|} + let tokens = @parser.tokenize(code) + let top_func = @parser.parse_program(tokens).top_functions["foo"] + let _ = typecheck_ctx.check_top_function_type_decl(top_func) + let typechecked_func = typecheck_ctx.check_top_function_body(top_func) + let knf_func = knf_ctx.top_function_to_knf(typechecked_func) + assert_true(knf_func.body.stmts.length() is 3) + // Second statement is local function 'bar' + assert_true( + knf_func.body.stmts[1] is ClosureDef(closure) && // closure 'bar' + closure.captured_vars.length() is 1 && // bar captures 'x' + closure.body.stmts is [s1, _] && + s1 is ClosureDef(inner_closure) && // inner closure 'baz' + inner_closure.captured_vars.length() is 1, // baz captures 'x' from bar + ) +} + +///| +test "Program Knf Transformation Test" { + let code = + #|let a = 3; + #|let b = 4; + #|fn fold(arr: Array[Int], f: (Int, Int) -> Int, initv: Int) -> Int { + #| let mut result = initv; + #| let mut i = 0; + #| while i < arr.length() { + #| result = f(result, arr[i]); + #| } + #| result + #|} + #| + #|fn main { + #| fn max(a, b) { if a > b { a } else { b } } + #| fn min(a, b) { if a < b { a } else { b } } + #| let numbers = [a, 1, b, 1, 5, 9, 2, 6, 5]; + #| let maximum = fold(numbers, max, -1000); + #| let minimum = fold(numbers, min, 1000); + #| let max_min_diff = maximum - minimum; + #| print_int(max_min_diff); + #|} + let tokens = @parser.tokenize(code) + let program = @parser.parse_program(tokens) + let program = @typecheck.typecheck(program) + let knf = knf_transform(program) + assert_true(knf.top_lets.contains("a")) + assert_true(knf.top_lets.contains("b")) + assert_true(knf.functions.contains("fold")) + assert_true(knf.functions.contains("main")) + inspect( + knf, + content=( + #|let a : Int = 3; + #|let b : Int = 4; + #|fn fold(arr: Array[Int], f: (Int, Int) -> Int, initv: Int) -> Int { + #| let mut result : Int = initv; + #| let mut i : Int = 0; + #| while {let tmp : () -> Int = arr.length; let tmp$1 : Int = tmp(); i < tmp$1; } { + #| let tmp$2 : Int = arr[i]; + #| result = f(result, tmp$2); + #| } + #| result; + #|} + #|fn main { + #| fn max(a$1 : Int, b$1 : Int) -> Int { + #| if a$1 > b$1 { + #| a$1; + #| } else { + #| b$1; + #| }; + #| } + #| fn min(a$1 : Int, b$1 : Int) -> Int { + #| if a$1 < b$1 { + #| a$1; + #| } else { + #| b$1; + #| }; + #| } + #| let tmp : Int = 1; + #| let tmp$1 : Int = 1; + #| let tmp$2 : Int = 5; + #| let tmp$3 : Int = 9; + #| let tmp$4 : Int = 2; + #| let tmp$5 : Int = 6; + #| let tmp$6 : Int = 5; + #| let numbers : Array[Int] = [a, tmp, b, tmp$1, tmp$2, tmp$3, tmp$4, tmp$5, tmp$6]::Array[Int]; + #| let tmp$7 : Int = 1000; + #| let tmp$8 : Int = -tmp$7; + #| let maximum : Int = fold(numbers, max, tmp$8); + #| let tmp$9 : Int = 1000; + #| let minimum : Int = fold(numbers, min, tmp$9); + #| let max_min_diff : Int = maximum - minimum; + #| print_int(max_min_diff); + #|} + #| + ), + ) +} diff --git a/src/knf/let_stmt.mbt b/src/knf/let_stmt.mbt new file mode 100644 index 0000000..b3844f9 --- /dev/null +++ b/src/knf/let_stmt.mbt @@ -0,0 +1,75 @@ +///| +pub fn Context::let_mut_stmt_to_knf( + self : Context, + let_mut_stmt : @typecheck.LetMutStmt, +) -> Array[KnfStmt] raise KnfTransformError { + let { name, ty, expr } = let_mut_stmt + let stmts = [] + let (init_stmts, init_knf_expr) = self.expr_to_knf(expr) + stmts.append(init_stmts) + let ty = self.type_to_knf(ty) + let name = self.add_new_name(name, ty) + stmts.push(LetMut(name, ty, init_knf_expr)) + stmts +} + +///| +pub fn Context::let_stmt_to_knf( + self : Context, + let_stmt : @typecheck.LetStmt, +) -> Array[KnfStmt] raise KnfTransformError { + let { pattern, ty, expr } = let_stmt + let stmts = [] + let (init_stmts, init_knf_expr) = self.expr_to_knf(expr) + stmts.append(init_stmts) + let ty = self.typekind_to_knf(ty) + match pattern.kind { + Ident(name) => { + let name = self.add_new_name(name, ty) + stmts.push(Let(name, ty, init_knf_expr)) + stmts + } + Wildcard => { + let name = Name::wildcard() + stmts.push(Let(name, ty, init_knf_expr)) + stmts + } + Tuple(names) => { + guard ty is Tuple(types) else { + raise KnfTransformError("tuple pattern requires tuple type") + } + if init_knf_expr is TupleLiteral(elem_names) { + guard names.length() == elem_names.length() else { + raise KnfTransformError("tuple pattern length mismatch") + } + for i in 0.. self.add_new_name(s, types[i]) + Wildcard => Name::wildcard() + Tuple(_) => panic() + } + let elem_ty = types[i] + let value_expr = KnfExpr::Ident(elem_names[i]) + stmts.push(Let(elem_name, elem_ty, value_expr)) + } + } else { + // Handle non-tuple expression case + let tmp_name = self.add_temp(ty) + for i in 0.. self.add_new_name(s, types[i]) + Wildcard => Name::wildcard() + Tuple(_) => panic() + } + let elem_ty = types[i] + let value_expr = KnfExpr::ArrayAccess( + tmp_name, + self.expr_to_knf_name(Int(i), Int, stmts), + ) + stmts.push(Let(elem_name, elem_ty, value_expr)) + } + } + stmts + } + } +} diff --git a/src/knf/moon.pkg.json b/src/knf/moon.pkg.json new file mode 100644 index 0000000..470da5c --- /dev/null +++ b/src/knf/moon.pkg.json @@ -0,0 +1,8 @@ +{ + "import": [ + "Lil-Ran/lilunar/typecheck" + ], + "test-import": [ + "Lil-Ran/lilunar/parser" + ] +} \ No newline at end of file diff --git a/src/knf/stmt.mbt b/src/knf/stmt.mbt new file mode 100644 index 0000000..93a75b3 --- /dev/null +++ b/src/knf/stmt.mbt @@ -0,0 +1,99 @@ +///| +pub(all) enum KnfStmt { + Let(Name, Type, KnfExpr) // let a : Int = 42; + LetMut(Name, Type, KnfExpr) // let mut a : Int = 0; + Assign(Name, KnfExpr) // a = 10; + ArrayPut(Name, Name, KnfExpr) // arr[3] = 5; + StructFieldSet(Name, String, Name) // point.x = 10; // MiniMoonBit does not support + While(KnfBlock, KnfBlock) // while (cond) { ... } + ExprStmt(KnfExpr) // expr; + Return(KnfExpr) // return expr; + ReturnUnit // return; + ClosureDef(KnfClosure) // closure definition +} + +///| +pub fn Context::stmt_to_knf( + self : Context, + stmt : @typecheck.Stmt, +) -> Array[KnfStmt] raise KnfTransformError { + match stmt.kind { + LetStmt(let_stmt) => self.let_stmt_to_knf(let_stmt) + LetMutStmt(let_mut_stmt) => self.let_mut_stmt_to_knf(let_mut_stmt) + AssignStmt(assign_stmt) => self.assign_stmt_to_knf(assign_stmt) + WhileStmt(while_stmt) => self.while_stmt_to_knf(while_stmt) + ExprStmt(expr_stmt) => { + let stmts = [] + let (expr_stmts, expr_knf_expr) = self.expr_to_knf(expr_stmt) + stmts.append(expr_stmts) + stmts.push(ExprStmt(expr_knf_expr)) + stmts + } + ReturnStmt(return_stmt) => + match return_stmt.ty { + Unit => [ReturnUnit] + _ => { + let stmts = [] + let (expr_stmts, expr_knf_expr) = self.expr_to_knf(return_stmt) + stmts.append(expr_stmts) + stmts.push(Return(expr_knf_expr)) + stmts + } + } + LocalFunction(local_function) => { + let closure = self.local_function_to_knf(local_function) + [ClosureDef(closure)] + } + } +} + +///| +pub fn Context::while_stmt_to_knf( + self : Context, + while_stmt : @typecheck.WhileStmt, +) -> Array[KnfStmt] raise KnfTransformError { + let (cond_stmts, cond_knf_expr) = self.expr_to_knf(while_stmt.cond) + [ + While( + { stmts: [..cond_stmts, ExprStmt(cond_knf_expr)], ty: Bool }, + self.block_expr_to_knf(while_stmt.body), + ), + ] +} + +///| +pub fn KnfStmt::to_string(self : KnfStmt, ident? : Int = 0) -> String { + let s = match self { + Let(name, ty, expr) => "let \{name} : \{ty} = \{expr};" + LetMut(name, ty, expr) => "let mut \{name} : \{ty} = \{expr};" + Assign(name, expr) => "\{name} = \{expr};" + ArrayPut(array_name, index_name, value_expr) => + "\{array_name}[\{index_name}] = \{value_expr};" + StructFieldSet(struct_name, field_name, value_name) => + "\{struct_name}.\{field_name} = \{value_name};" + While(cond_block, body_block) => + if cond_block.stmts.length() <= 3 { + let cond_str = cond_block.nested_to_string() + let body_str = body_block.to_string(ident) + "while \{cond_str} \{body_str}" + } else { + let cond_str = cond_block.to_string(ident) + let body_str = body_block.to_string(ident) + "while \{cond_str} \{body_str}" + } + ExprStmt(expr) => { + let expr_str = expr.to_string(ident~) + "\{expr_str};" + } + Return(expr) => "return \{expr};" + ReturnUnit => "return;" + ClosureDef(closure) => closure.to_string(ident~) + } + let indent_str = " ".repeat(ident) + "\{indent_str}\{s}" +} + +///| +pub impl Show for KnfStmt with output(self, logger) { + logger.write_string(self.to_string(ident=0)) +} diff --git a/src/knf/struct_def.mbt b/src/knf/struct_def.mbt new file mode 100644 index 0000000..c01c114 --- /dev/null +++ b/src/knf/struct_def.mbt @@ -0,0 +1,47 @@ +///| +pub(all) struct KnfStructDef { + name : String + // field name, is_mut, field type + fields : Array[(String, Bool, Type)] +} + +///| +pub fn Context::struct_def_to_knf( + self : Context, + struct_def : @typecheck.StructDef, +) -> KnfStructDef raise KnfTransformError { + let { name, fields } = struct_def + let knf_fields = [] + for field in fields { + let { name: field_name, ty } = field + let field_type = self.typekind_to_knf(ty.kind) + knf_fields.push((field_name, ty.mutable, field_type)) + } + { name, fields: knf_fields } +} + +///| +pub fn KnfStructDef::get_field_index( + self : KnfStructDef, + field_name : String, +) -> Int? { + for i, f in self.fields { + let (name, _, _) = f + if name == field_name { + return Some(i) + } + } + None +} + +///| +pub impl Show for KnfStructDef with output(self, logger) { + let { name, fields } = self + logger.write_string("struct \{name} {\n") + for field in fields { + let (field_name, is_mut, field_type) = field + let mutability = if is_mut { "mut " } else { "" } + logger.write_string(" \{mutability}\{field_name}: \{field_type};\n") + } + logger.write_string("}\n") +} diff --git a/src/knf/top_let.mbt b/src/knf/top_let.mbt new file mode 100644 index 0000000..1d8e781 --- /dev/null +++ b/src/knf/top_let.mbt @@ -0,0 +1,29 @@ +///| +pub(all) struct KnfTopLet { + name : Name + ty : Type + expr : KnfExpr + init_stmts : Array[KnfStmt] +} + +///| +pub fn Context::top_let_to_knf( + self : Context, + top_let : @typecheck.TopLet, +) -> KnfTopLet raise KnfTransformError { + let { name, ty, expr } = top_let + let (init_stmts, expr) = self.expr_to_knf(expr) + let ty = self.type_to_knf(ty) + let name = self.add_new_name(name, ty) + self.globals.set(top_let.name, ty) + { name, ty, expr, init_stmts } +} + +///| +pub impl Show for KnfTopLet with output(self, logger) { + let { name, ty, expr, init_stmts } = self + for stmt in init_stmts { + logger.write_string(" \{stmt};") + } + logger.write_string("let \{name} : \{ty} = \{expr};") +} diff --git a/src/knf/type.mbt b/src/knf/type.mbt new file mode 100644 index 0000000..aed9b14 --- /dev/null +++ b/src/knf/type.mbt @@ -0,0 +1,64 @@ +///| +pub(all) enum Type { + Unit + Int + Bool + Double + Array(Type) + Struct(String) + Tuple(Array[Type]) + Function(Array[Type], Type) +} + +///| +pub impl Show for Type with output(self, logger) { + let s = match self { + Unit => "Unit" + Int => "Int" + Bool => "Bool" + Double => "Double" + Array(elem_type) => "Array[\{elem_type}]" + Struct(name) => "\{name}" + Tuple(elem_types) => { + let elem_strs = elem_types.map(et => "\{et}").join(", ") + "(\{elem_strs})" + } + Function(param_types, ret_type) => { + let param_strs = param_types.map(pt => "\{pt}").join(", ") + "(\{param_strs}) -> \{ret_type}" + } + } + logger.write_string(s) +} + +///| +pub fn Context::typekind_to_knf( + self : Context, + tk : @typecheck.TypeKind, +) -> Type raise KnfTransformError { + match tk { + Any => raise KnfTransformError("Cannot convert 'Any' to KNF type.") + Struct(name) => Struct(name) + Function(param_types, ret_type) => + Function( + param_types.map(pt => self.typekind_to_knf(pt)), + self.typekind_to_knf(ret_type), + ) + Array(t) => Array(self.typekind_to_knf(t)) + Tuple(types) => Tuple(types.map(t => self.typekind_to_knf(t))) + Double => Double + Int => Int + Bool => Bool + Unit => Unit + TypeVar(_) => + raise KnfTransformError("Cannot convert 'TypeVar' to KNF type.") + } +} + +///| +pub fn Context::type_to_knf( + self : Context, + t : @typecheck.Type, +) -> Type raise KnfTransformError { + self.typekind_to_knf(t.kind) +} diff --git a/src/typecheck/expr_atom.mbt b/src/typecheck/expr_atom.mbt index 3eb4b02..b504188 100644 --- a/src/typecheck/expr_atom.mbt +++ b/src/typecheck/expr_atom.mbt @@ -81,6 +81,14 @@ pub fn Context::check_atom_expr( let field_exprs = [] for field in fields { let (field_name, field_expr) = field + for existing_field in field_exprs { + let (existing_field_name, _) = existing_field + if existing_field_name == field_name { + raise TypeCheckError( + "Duplicate field '\{field_name}' in struct '\{name}' construction.", + ) + } + } let expected_type = def .get_field_type(field_name) .or_error( diff --git a/src/typecheck/expr_block.mbt b/src/typecheck/expr_block.mbt index 658e177..4711cd0 100644 --- a/src/typecheck/expr_block.mbt +++ b/src/typecheck/expr_block.mbt @@ -15,9 +15,16 @@ pub fn Context::check_block_expr( self.enter_scope() let checked_stmts = stmts.map(stmt => self.check_stmt(stmt)) self.exit_scope() - let stmts_count = checked_stmts.length() - guard stmts_count > 0 && checked_stmts[stmts_count - 1].kind is ExprStmt(expr) else { + guard checked_stmts is [.., { kind: ExprStmt(expr) }] else { return { stmts: checked_stmts, ty: Unit } } - { stmts: checked_stmts, ty: expr.ty } + if checked_stmts is [.., { kind: ReturnStmt(ret) }] { + { stmts: checked_stmts, ty: ret.ty } + } else if checked_stmts is [.., stmt1, stmt2] && + stmt2.kind is ExprStmt({ ty: Unit, .. }) && + stmt1.kind is ReturnStmt(ret) { + { stmts: checked_stmts, ty: ret.ty } + } else { + { stmts: checked_stmts, ty: expr.ty } + } } diff --git a/src/typecheck/top_function.mbt b/src/typecheck/top_function.mbt index f2d86b0..69ed16f 100644 --- a/src/typecheck/top_function.mbt +++ b/src/typecheck/top_function.mbt @@ -8,7 +8,7 @@ pub(all) struct Param { pub(all) struct TopFunction { fname : String param_list : Array[Param] - ret_ty : TypeKind + ty : TypeKind body : BlockExpr } derive(Show) @@ -28,7 +28,8 @@ pub fn Context::check_top_function_body( } let param_names = func.params.map(param => param.0) let param_list = [] - guard self.func_types.get(func.id) is Some(Function(param_types, ret_ty)) else { + guard self.func_types.get(func.id) is Some(ty) && + ty is Function(param_types, ret_ty) else { raise TypeCheckError("Function type for '\{func.id}' not found.") } self.enter_scope() @@ -48,7 +49,7 @@ pub fn Context::check_top_function_body( if func.user_defined_type is Some(UserDefined(udt)) { self.type_env.local_.remove("$Generic$\{udt}") } - { fname: func.id, param_list, ret_ty, body: checked_body } + { fname: func.id, param_list, ty, body: checked_body } } ///| diff --git a/src/typecheck/typechecker.mbt b/src/typecheck/typechecker.mbt index 7bf364c..509426c 100644 --- a/src/typecheck/typechecker.mbt +++ b/src/typecheck/typechecker.mbt @@ -33,7 +33,7 @@ pub fn Context::substitute_type_var( name: param.name, ty: self.deref_type_var(param.ty), }), - ret_ty: self.deref_type_var(top_func.ret_ty), + ty: self.deref_type_var(top_func.ty), body: self.substitute_type_var_for_block_expr(top_func.body), }) }