diff --git a/src/parser/ast.mbt b/src/parser/ast.mbt index 3c69d86..e739a6d 100644 --- a/src/parser/ast.mbt +++ b/src/parser/ast.mbt @@ -35,9 +35,6 @@ enum MulDivRemOp { Rem } derive(Show) -///| -struct Block(Array[Stmt], Expr?) derive(Show) - ///| enum Pattern { Wildcard @@ -68,7 +65,7 @@ enum Expr { Tuple(Array[Expr]) Array(Array[Expr]) Identifier(String) - Block(Block) + Block(Array[Stmt]) } derive(Show) ///| @@ -77,7 +74,7 @@ struct Function { user_defined_type : Type? params : Array[(String, Type?)] return_type : Type? - body : Block + body : Array[Stmt] } derive(Show) ///| @@ -94,7 +91,7 @@ enum Stmt { Assign(Expr, Expr) While(Expr, Array[Stmt]) Expr(Expr) - Return(Expr?) + Return(Expr) LocalFunction(Function) } derive(Show) @@ -144,24 +141,26 @@ fn parse_type( (Array(elem_type), rest) } [LParen, .. rest] => { - // XXX: function_type has at least one type in the argument list? - let (first_type, rest) = parse_type(rest) - let types = [first_type] + let types = [] loop rest { - [Comma, .. r] => { - let (next_type, r) = parse_type(r) - types.push(next_type) - continue r - } [RParen, Arrow, .. r] => { let (return_type, r) = parse_type(r) (Function(types, return_type), r) } [RParen, .. r] => (Tuple(types), r) - _ => - raise ParseError( - "Expected ',' or ')' or ')' '->' in tuple/function type", - ) + r => { + let (type_, r) = parse_type(r) + types.push(type_) + match r { + [Comma, .. r] => continue r + [RParen, ..] => continue r + _ => + raise ParseError( + "Expected ',' or ')' or ')' '->' in tuple/function type", + ) + } + continue r + } } } [UpperIdentifier(name), LBracket, .. rest] => { @@ -257,6 +256,201 @@ fn parse_enum_decl( } } +///| +fn parse_optional_type_annotation( + tokens : ArrayView[Token], +) -> (Type?, ArrayView[Token]) raise ParseError { + if tokens is [Colon, .. r] { + let (t, r) = parse_type(r) + (Some(t), r) + } else { + (None, tokens) + } +} + +///| +fn parse_let_stmt_type_expr( + tokens : ArrayView[Token], +) -> (Type?, Expr, ArrayView[Token]) raise ParseError { + let (type_, rest) = parse_optional_type_annotation(tokens) + guard rest is [Assign, .. rest] else { + raise ParseError( + "Expected '=' after identifier or type annotation in let declaration", + ) + } + let (expr, rest) = parse_expr(rest) + guard rest is [Semicolon, .. rest] else { + raise ParseError("Expected ';' after let statement") + } + (type_, expr, rest) +} + +///| +/// Returns `(Stmt, is_end_expr, rest_tokens)`. +/// Semicolon is consumed for statements. +/// End expr is wrapped in `Stmt::Return`. +fn parse_stmt_or_expr_end( + tokens : ArrayView[Token], +) -> (Stmt, Bool, ArrayView[Token]) raise ParseError { + match tokens { + + // let_tuple_stmt + [Let, LParen, .. rest] => { + let bindings : Array[Binding] = [] + let rest = loop rest { + [RParen, .. r] => r + [LowerIdentifier(_) | UpperIdentifier(_) | Wildcard as binding, .. r] => { + bindings.push( + match binding { + LowerIdentifier(name) | UpperIdentifier(name) => Identifier(name) + Wildcard => Wildcard + _ => raise ParseError("Unreachable") + }, + ) + match r { + [Comma, .. r] => continue r + [RParen, ..] => continue r + _ => raise ParseError("Expected ',' or ')' after let binding") + } + } + _ => raise ParseError("Unexpected token in let tuple binding list") + } + let (type_, expr, rest) = parse_let_stmt_type_expr(rest) + (LetTuple(bindings, type_, expr), false, rest) + } + + // let_mut_stmt + [Let, Mut, LowerIdentifier(id) | UpperIdentifier(id), .. rest] => { + let (type_, expr, rest) = parse_let_stmt_type_expr(rest) + (LetMut(id, type_, expr), false, rest) + } + + // let_stmt + [Let, LowerIdentifier(_) | UpperIdentifier(_) | Wildcard as b, .. rest] => { + let binding : Binding = match b { + LowerIdentifier(name) | UpperIdentifier(name) => Identifier(name) + Wildcard => Wildcard + _ => raise ParseError("Unreachable") + } + let (type_, expr, rest) = parse_let_stmt_type_expr(rest) + (Let(binding, type_, expr), false, rest) + } + + // nontop_fn_decl + [Fn, LowerIdentifier(id) | UpperIdentifier(id), LParen, .. rest] => { + let params = [] + let rest = loop rest { + [RParen, .. r] => r + [LowerIdentifier(param_name) | UpperIdentifier(param_name), .. r] => { + let (param_type, r) = parse_optional_type_annotation(r) + params.push((param_name, param_type)) + match r { + [Comma, .. r] => continue r + [RParen, ..] => continue r + _ => + raise ParseError("Expected ',' or ')' after function parameter") + } + } + _ => raise ParseError("Unexpected token in function parameter list") + } + let (return_type, rest) = if rest is [Arrow, .. r] { + let (t, r) = parse_type(r) + (Some(t), r) + } else { + (None, rest) + } + guard parse_block_expr(rest) is (Block(body), rest) + ( + LocalFunction({ id, user_defined_type: None, params, return_type, body }), + false, + rest, + ) + } + + // return_stmt + [Return, Semicolon, .. rest] => (Return(Literal(Unit)), false, rest) + [Return, .. rest] => { + let (expr, rest) = parse_expr(rest) + guard rest is [Semicolon, .. rest] else { + raise ParseError("Expected ';' after return statement") + } + (Return(expr), false, rest) + } + + // while_stmt + [While, .. rest] => { + let (cond_expr, rest) = parse_expr(rest) + let stmts = [] + loop rest { + [RCurlyBracket, .. r] => (While(cond_expr, stmts), false, r) + r => { + let (stmt, is_end_expr, rest) = parse_stmt_or_expr_end(r) + if is_end_expr { + raise ParseError( + "Unexpected return expression in while statement body", + ) + } + stmts.push(stmt) + continue rest + } + } + } + + // assign_stmt, expr_stmt, end_expr + _ => { + let (expr, rest) = parse_expr(tokens) + match rest { + [Assign, .. r] => { + let valid_left_value = loop expr { + Identifier(_) => true + FieldAccess(base, _) => continue base + IndexAccess(base, _) => continue base + _ => false + } + guard valid_left_value else { + raise ParseError("Invalid left-hand side in assignment") + } + let (rhs_expr, r) = parse_expr(r) + guard r is [Semicolon, .. r] else { + raise ParseError("Expected ';' after assignment statement") + } + (Assign(expr, rhs_expr), false, r) + } + [Semicolon, .. r] => (Expr(expr), false, r) + [RCurlyBracket, ..] => (Return(expr), true, rest) + _ => raise ParseError("Expected ';' or '}' after expression statement") + } + } + } +} + +///| +fn parse_block_expr( + tokens : ArrayView[Token], +) -> (Expr, ArrayView[Token]) raise ParseError { + guard tokens is [LCurlyBracket, .. rest] else { + raise ParseError("Expected '{' at start of block expression") + } + let stmts = [] + let rest = loop rest { + r => { + let (stmt, is_end_expr, r) = parse_stmt_or_expr_end(r) + stmts.push(stmt) + if is_end_expr { + guard r is [RCurlyBracket, .. r] else { + raise ParseError("Expected '}' at end of block expression") + } + break r + } else if r is [RCurlyBracket, .. r] { + stmts.push(Return(Literal(Unit))) + break r + } + continue r + } + } + (Block(stmts), rest) +} + ///| fn parse_if_expr( tokens : ArrayView[Token], @@ -382,6 +576,7 @@ fn parse_value_level_expr( // identifier_expr [LowerIdentifier(name) | UpperIdentifier(name), .. rest] => (Identifier(name), rest) + _ => raise ParseError("Unsupported value level expression") } } @@ -400,7 +595,23 @@ fn parse_get_or_apply_level_expr( result = IndexAccess(result, index) continue r } - [LParen, .. r] => ... + [LParen, .. r] => { + let args = [] + let rest = loop r { + [RParen, .. r] => r + r => { + let (arg, r) = parse_expr(r) + args.push(arg) + match r { + [Comma, .. r] => continue r + [RParen, ..] => break r + _ => raise ParseError("Expected ',' or ')' in function call args") + } + } + } + result = FunctionCall(result, args) + continue rest + } [Dot, UpperIdentifier(field_name) | LowerIdentifier(field_name), .. r] => { result = FieldAccess(result, field_name) continue r @@ -529,20 +740,10 @@ pub fn parse_program(tokens : Array[Token]) -> Program raise ParseError { loop tokens[:] { [EOF] => { top_lets, top_functions, struct_defs, enum_defs } [Let, LowerIdentifier(id) | UpperIdentifier(id), .. rest] => { - let (type_, rest) = match rest { - [Colon, .. r] => { - let (t, r) = parse_type(r) - (Some(t), r) - } - [Assign, ..] => (None, rest) - _ => - raise ParseError( - "Expected ':' or '=' after identifier in let declaration", - ) - } + let (type_, rest) = parse_optional_type_annotation(rest) guard rest is [Assign, .. rest] else { raise ParseError( - "Expected '=' after type annotation in let declaration", + "Expected '=' after identifier or type annotation in let declaration", ) } let (expr, rest) = parse_expr(rest) @@ -553,7 +754,7 @@ pub fn parse_program(tokens : Array[Token]) -> Program raise ParseError { continue rest } [Fn, LowerIdentifier("main"), ..] => { - let (body, rest) = parse_block_expr(tokens) + guard parse_block_expr(tokens) is (Block(body), rest) top_functions["main"] = { id: "main", user_defined_type: None, @@ -563,7 +764,56 @@ pub fn parse_program(tokens : Array[Token]) -> Program raise ParseError { } continue rest } - [Fn, ..] => ... + [Fn, .. rest] => { + let (user_defined_type, id, rest) = match rest { + [ + LBracket, + UpperIdentifier(type_), + RBracket, + UpperIdentifier(id) + | LowerIdentifier(id), + LParen, + .. r, + ] => { + let user_defined_type = Some(UserDefined(type_)) + (user_defined_type, id, r) + } + [UpperIdentifier(id) | LowerIdentifier(id), LParen, .. r] => + (None, id, r) + _ => + raise ParseError( + "Expected function name (with optional type) after 'fn'", + ) + } + let params = [] + let rest = loop rest { + [RParen, .. r] => r + [LowerIdentifier(param_name) | UpperIdentifier(param_name), Colon, .. r] => { + let (param_type, r) = parse_type(r) + params.push((param_name, Some(param_type))) + match r { + [Comma, .. r] => continue r + [RParen, ..] => continue r + _ => + raise ParseError("Expected ',' or ')' after function parameter") + } + } + _ => raise ParseError("Unexpected token in function parameter list") + } + guard rest is [Arrow, .. rest] else { + raise ParseError("Expected '->' after function parameter list") + } + let (return_type, rest) = parse_type(rest) + guard parse_block_expr(rest) is (Block(body), rest) + top_functions[id] = { + id, + user_defined_type, + params, + return_type: Some(return_type), + body, + } + continue rest + } [Struct, ..] as tokens => { let (struct_, rest) = parse_struct_decl(tokens) struct_defs[struct_.id] = struct_