feat: 表达式作为值 — if/else 和 block 可产生值
This commit is contained in:
@@ -451,6 +451,48 @@ static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) {
|
||||
return LLVMBuildLoad2(ctx->builder, elem_load_ty, elem_ptr, "arr_load");
|
||||
}
|
||||
|
||||
// 块表达式: { stmt*; expr } → 最后表达式的值
|
||||
case AST_BLOCK: {
|
||||
LLVMValueRef result = NULL;
|
||||
for (size_t i = 0; i < node->as.block.stmt_count; i++) {
|
||||
AstNode* stmt = node->as.block.stmts[i];
|
||||
bool is_last = (i == node->as.block.stmt_count - 1);
|
||||
if (is_last && stmt->kind == AST_EXPR_STMT && node->type.kind != TYPE_VOID) {
|
||||
result = codegen_expr(ctx, stmt->as.expr_stmt.expr);
|
||||
} else {
|
||||
codegen_stmt(ctx, stmt);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// if 表达式: if cond { a } else { b }
|
||||
case AST_IF_STMT: {
|
||||
if (node->type.kind == TYPE_VOID) { codegen_stmt(ctx, node); return NULL; }
|
||||
LLVMValueRef cond_val = codegen_expr(ctx, node->as.if_stmt.cond);
|
||||
if (!cond_val) return NULL;
|
||||
LLVMTypeRef res_ty = type_info_to_llvm(ctx, &node->type);
|
||||
LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, res_ty, "if_res");
|
||||
LLVMValueRef func = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx->builder));
|
||||
LLVMBasicBlockRef then_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "then");
|
||||
LLVMBasicBlockRef else_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "else");
|
||||
LLVMBasicBlockRef merge_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "if_merge");
|
||||
LLVMBuildCondBr(ctx->builder, cond_val, then_bb, else_bb);
|
||||
|
||||
LLVMPositionBuilderAtEnd(ctx->builder, then_bb);
|
||||
LLVMValueRef then_val = codegen_expr(ctx, node->as.if_stmt.then_block);
|
||||
if (then_val) LLVMBuildStore(ctx->builder, then_val, alloca);
|
||||
LLVMBuildBr(ctx->builder, merge_bb);
|
||||
|
||||
LLVMPositionBuilderAtEnd(ctx->builder, else_bb);
|
||||
LLVMValueRef else_val = codegen_expr(ctx, node->as.if_stmt.else_block);
|
||||
if (else_val) LLVMBuildStore(ctx->builder, else_val, alloca);
|
||||
LLVMBuildBr(ctx->builder, merge_bb);
|
||||
|
||||
LLVMPositionBuilderAtEnd(ctx->builder, merge_bb);
|
||||
return LLVMBuildLoad2(ctx->builder, res_ty, alloca, "if_val");
|
||||
}
|
||||
|
||||
default:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
+17
-1
@@ -258,7 +258,23 @@ static AstNode* parse_expr_prec(Parser* p, Precedence min_prec, ErrorInfo* error
|
||||
AstNode* left = NULL;
|
||||
|
||||
// 前缀解析
|
||||
if (tok->kind == TOK_MINUS || tok->kind == TOK_BANG) {
|
||||
if (tok->kind == TOK_IF) {
|
||||
// if-expr: if cond { then } else { else } → AST_IF_STMT (表达式位置)
|
||||
const Token* if_tok = advance(p);
|
||||
AstNode* cond = parse_expr(p, error);
|
||||
if (!cond) return NULL;
|
||||
AstNode* then_block = parse_block(p, error);
|
||||
if (!then_block) return NULL;
|
||||
AstNode* else_block = NULL;
|
||||
if (match(p, TOK_ELSE)) {
|
||||
if (peek(p)->kind == TOK_IF)
|
||||
else_block = parse_expr_prec(p, min_prec, error); // else if
|
||||
else
|
||||
else_block = parse_block(p, error);
|
||||
if (!else_block) return NULL;
|
||||
}
|
||||
left = ast_make_if(p->arena, cond, then_block, else_block, tok_loc(if_tok));
|
||||
} else if (tok->kind == TOK_MINUS || tok->kind == TOK_BANG) {
|
||||
left = parse_unary(p, error);
|
||||
} else if (tok->kind == TOK_LPAREN) {
|
||||
left = parse_group(p, error);
|
||||
|
||||
@@ -423,6 +423,8 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
|
||||
case AST_ENUM_VARIANT: analyze_enum_variant(node, scope, errors, a); break;
|
||||
case AST_INDEX_EXPR: analyze_index_expr(node, scope, errors, a); break;
|
||||
case AST_METHOD_CALL: analyze_method_call(node, scope, errors, a); break;
|
||||
case AST_IF_STMT: analyze_node(node, scope, errors, a); break;
|
||||
case AST_BLOCK: analyze_node(node, scope, errors, a); break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
@@ -602,6 +604,17 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
|
||||
for (size_t i = 0; i < node->as.block.stmt_count; i++) {
|
||||
analyze_node(node->as.block.stmts[i], scope, errors, a);
|
||||
}
|
||||
// 表达式作为值: 块类型 = 最后一条表达式语句的类型
|
||||
if (node->as.block.stmt_count > 0) {
|
||||
AstNode* last = node->as.block.stmts[node->as.block.stmt_count - 1];
|
||||
if (last->kind == AST_EXPR_STMT) {
|
||||
TypeKind ek = last->as.expr_stmt.expr->type.kind;
|
||||
if (ek != TYPE_ERROR && ek != TYPE_VOID) {
|
||||
node->type.kind = ek;
|
||||
node->type.struct_name = last->as.expr_stmt.expr->type.struct_name;
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case AST_LET_STMT: {
|
||||
@@ -773,6 +786,19 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
|
||||
if (node->as.if_stmt.else_block) {
|
||||
analyze_node(node->as.if_stmt.else_block, scope, errors, a);
|
||||
}
|
||||
// 表达式作为值: if 类型 = 两个分支的共同非 void 类型
|
||||
{
|
||||
AstNode* tb = node->as.if_stmt.then_block;
|
||||
AstNode* eb = node->as.if_stmt.else_block;
|
||||
if (tb && eb) {
|
||||
TypeKind tt = tb->type.kind, et = eb->type.kind;
|
||||
if (tt == et && tt != TYPE_VOID && tt != TYPE_ERROR) {
|
||||
node->type.kind = tt;
|
||||
if (tt == TYPE_STRUCT && tb->type.struct_name)
|
||||
node->type.struct_name = tb->type.struct_name;
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case AST_WHILE_STMT:
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
fn main() -> i64 {
|
||||
// if 作为表达式: let x = if cond { val1 } else { val2 }
|
||||
let a = if true { 10; } else { 20; };
|
||||
print_i64(a); // 10
|
||||
|
||||
let b = if false { 10; } else { 20; };
|
||||
print_i64(b); // 20
|
||||
|
||||
// 嵌套 if 表达式
|
||||
let c = if a > 5 { 100; } else { 0; };
|
||||
print_i64(c); // 100
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user