diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c index 0b5295f..549b3a8 100644 --- a/src/codegen/codegen.c +++ b/src/codegen/codegen.c @@ -451,6 +451,48 @@ static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) { return LLVMBuildLoad2(ctx->builder, elem_load_ty, elem_ptr, "arr_load"); } + // 块表达式: { stmt*; expr } → 最后表达式的值 + case AST_BLOCK: { + LLVMValueRef result = NULL; + for (size_t i = 0; i < node->as.block.stmt_count; i++) { + AstNode* stmt = node->as.block.stmts[i]; + bool is_last = (i == node->as.block.stmt_count - 1); + if (is_last && stmt->kind == AST_EXPR_STMT && node->type.kind != TYPE_VOID) { + result = codegen_expr(ctx, stmt->as.expr_stmt.expr); + } else { + codegen_stmt(ctx, stmt); + } + } + return result; + } + + // if 表达式: if cond { a } else { b } + case AST_IF_STMT: { + if (node->type.kind == TYPE_VOID) { codegen_stmt(ctx, node); return NULL; } + LLVMValueRef cond_val = codegen_expr(ctx, node->as.if_stmt.cond); + if (!cond_val) return NULL; + LLVMTypeRef res_ty = type_info_to_llvm(ctx, &node->type); + LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, res_ty, "if_res"); + LLVMValueRef func = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx->builder)); + LLVMBasicBlockRef then_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "then"); + LLVMBasicBlockRef else_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "else"); + LLVMBasicBlockRef merge_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "if_merge"); + LLVMBuildCondBr(ctx->builder, cond_val, then_bb, else_bb); + + LLVMPositionBuilderAtEnd(ctx->builder, then_bb); + LLVMValueRef then_val = codegen_expr(ctx, node->as.if_stmt.then_block); + if (then_val) LLVMBuildStore(ctx->builder, then_val, alloca); + LLVMBuildBr(ctx->builder, merge_bb); + + LLVMPositionBuilderAtEnd(ctx->builder, else_bb); + LLVMValueRef else_val = codegen_expr(ctx, node->as.if_stmt.else_block); + if (else_val) LLVMBuildStore(ctx->builder, else_val, alloca); + LLVMBuildBr(ctx->builder, merge_bb); + + LLVMPositionBuilderAtEnd(ctx->builder, merge_bb); + return LLVMBuildLoad2(ctx->builder, res_ty, alloca, "if_val"); + } + default: return NULL; } diff --git a/src/parser/parser.c b/src/parser/parser.c index 66866b1..23973a5 100644 --- a/src/parser/parser.c +++ b/src/parser/parser.c @@ -258,7 +258,23 @@ static AstNode* parse_expr_prec(Parser* p, Precedence min_prec, ErrorInfo* error AstNode* left = NULL; // 前缀解析 - if (tok->kind == TOK_MINUS || tok->kind == TOK_BANG) { + if (tok->kind == TOK_IF) { + // if-expr: if cond { then } else { else } → AST_IF_STMT (表达式位置) + const Token* if_tok = advance(p); + AstNode* cond = parse_expr(p, error); + if (!cond) return NULL; + AstNode* then_block = parse_block(p, error); + if (!then_block) return NULL; + AstNode* else_block = NULL; + if (match(p, TOK_ELSE)) { + if (peek(p)->kind == TOK_IF) + else_block = parse_expr_prec(p, min_prec, error); // else if + else + else_block = parse_block(p, error); + if (!else_block) return NULL; + } + left = ast_make_if(p->arena, cond, then_block, else_block, tok_loc(if_tok)); + } else if (tok->kind == TOK_MINUS || tok->kind == TOK_BANG) { left = parse_unary(p, error); } else if (tok->kind == TOK_LPAREN) { left = parse_group(p, error); diff --git a/src/sema/sema.c b/src/sema/sema.c index f004634..7eb8f4a 100644 --- a/src/sema/sema.c +++ b/src/sema/sema.c @@ -423,6 +423,8 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* case AST_ENUM_VARIANT: analyze_enum_variant(node, scope, errors, a); break; case AST_INDEX_EXPR: analyze_index_expr(node, scope, errors, a); break; case AST_METHOD_CALL: analyze_method_call(node, scope, errors, a); break; + case AST_IF_STMT: analyze_node(node, scope, errors, a); break; + case AST_BLOCK: analyze_node(node, scope, errors, a); break; default: break; } } @@ -602,6 +604,17 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* for (size_t i = 0; i < node->as.block.stmt_count; i++) { analyze_node(node->as.block.stmts[i], scope, errors, a); } + // 表达式作为值: 块类型 = 最后一条表达式语句的类型 + if (node->as.block.stmt_count > 0) { + AstNode* last = node->as.block.stmts[node->as.block.stmt_count - 1]; + if (last->kind == AST_EXPR_STMT) { + TypeKind ek = last->as.expr_stmt.expr->type.kind; + if (ek != TYPE_ERROR && ek != TYPE_VOID) { + node->type.kind = ek; + node->type.struct_name = last->as.expr_stmt.expr->type.struct_name; + } + } + } break; case AST_LET_STMT: { @@ -773,6 +786,19 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* if (node->as.if_stmt.else_block) { analyze_node(node->as.if_stmt.else_block, scope, errors, a); } + // 表达式作为值: if 类型 = 两个分支的共同非 void 类型 + { + AstNode* tb = node->as.if_stmt.then_block; + AstNode* eb = node->as.if_stmt.else_block; + if (tb && eb) { + TypeKind tt = tb->type.kind, et = eb->type.kind; + if (tt == et && tt != TYPE_VOID && tt != TYPE_ERROR) { + node->type.kind = tt; + if (tt == TYPE_STRUCT && tb->type.struct_name) + node->type.struct_name = tb->type.struct_name; + } + } + } break; case AST_WHILE_STMT: diff --git a/test/programs/30_if_expr.l b/test/programs/30_if_expr.l new file mode 100644 index 0000000..68e9c27 --- /dev/null +++ b/test/programs/30_if_expr.l @@ -0,0 +1,14 @@ +fn main() -> i64 { + // if 作为表达式: let x = if cond { val1 } else { val2 } + let a = if true { 10; } else { 20; }; + print_i64(a); // 10 + + let b = if false { 10; } else { 20; }; + print_i64(b); // 20 + + // 嵌套 if 表达式 + let c = if a > 5 { 100; } else { 0; }; + print_i64(c); // 100 + + return 0; +}