#include "codegen.h"

/* NOTE(review): the header names of the five angle-bracket includes were
 * missing from the reviewed text. The list below is reconstructed from the
 * identifiers this file uses (LLVM-C builder/verifier API, malloc/free,
 * strcmp, bool) -- confirm against the original file. */
#include <llvm-c/Core.h>
#include <llvm-c/Analysis.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

// === Internal state ===

/* One local variable: its name and the stack slot (alloca) that holds it.
 * `name` is stored as given (not copied). */
typedef struct VarEntry {
    const char* name;
    LLVMValueRef alloca;
    struct VarEntry* next;
} VarEntry;

/* One declared function. Only `name` and `fn` are read by this file; the
 * remaining fields are initialised to empty values by add_fn(). */
typedef struct FnEntry {
    const char* name;
    LLVMValueRef fn;
    TypeKind ret;
    TypeKind* params;
    size_t pc;
    struct FnEntry* next;
} FnEntry;

typedef struct {
    LLVMContextRef context;  // explicit Context (required by LLVM 19+)
    LLVMModuleRef module;
    LLVMBuilderRef builder;
    VarEntry* var_table;     // locals of the function currently being emitted
    const char* error;       // first internal failure (e.g. out of memory), or NULL
    FnEntry* fn_table;
    // printf runtime support (the built-in print functions delegate to printf)
    LLVMValueRef printf_fn;
    LLVMTypeRef printf_ty;
} CgCtx;

// === Type mapping (needs the Context) ===

/* Maps a language type to its LLVM type; anything other than i64/f64/bool
 * maps to void. */
static LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) {
    switch (kind) {
    case TYPE_I64:  return LLVMInt64TypeInContext(ctx->context);
    case TYPE_F64:  return LLVMDoubleTypeInContext(ctx->context);
    case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context);
    default:        return LLVMVoidTypeInContext(ctx->context);
    }
}

/* Builds the LLVM constant for a literal node. Returns NULL for literal
 * types other than i64/f64/bool. */
static LLVMValueRef to_llvm_const(LLVMTypeRef ty, AstNode* lit) {
    switch (lit->as.literal.lit_type) {
    case TYPE_I64:
        return LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, true);
    case TYPE_F64:
        return LLVMConstReal(ty, lit->as.literal.f64_val);
    case TYPE_BOOL:
        return LLVMConstInt(ty, lit->as.literal.bool_val ? 1 : 0, false);
    default:
        return NULL;
    }
}

/* True when `v` is a double-typed LLVM value. */
static bool is_float_value(LLVMValueRef v) {
    return LLVMGetTypeKind(LLVMTypeOf(v)) == LLVMDoubleTypeKind;
}

// === Variable table ===

/* Returns the alloca registered for `name`, or NULL. The most recently added
 * entry wins because add_var() pushes at the head of the list. */
static LLVMValueRef find_var(CgCtx* ctx, const char* name) {
    for (VarEntry* e = ctx->var_table; e; e = e->next) {
        if (strcmp(e->name, name) == 0)
            return e->alloca;
    }
    return NULL;
}

/* Registers a variable. On allocation failure the variable is not added and
 * ctx->error is set. */
static void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca) {
    VarEntry* e = malloc(sizeof(*e));
    if (!e) {
        if (!ctx->error)
            ctx->error = "内存分配失败";
        return;
    }
    e->name = name;
    e->alloca = alloca;
    e->next = ctx->var_table;
    ctx->var_table = e;
}

/* Frees every entry of the variable table and leaves it empty. */
static void free_vars(CgCtx* ctx) {
    VarEntry* e = ctx->var_table;
    while (e) {
        VarEntry* next = e->next;
        free(e);
        e = next;
    }
    ctx->var_table = NULL;
}

// === Function table ===

/* Returns the LLVM function registered for `name`, or NULL. */
static LLVMValueRef find_fn(CgCtx* ctx, const char* name) {
    for (FnEntry* e = ctx->fn_table; e; e = e->next) {
        if (strcmp(e->name, name) == 0)
            return e->fn;
    }
    return NULL;
}

/* Registers a function. On allocation failure the function is not added and
 * ctx->error is set. */
static void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn) {
    FnEntry* e = malloc(sizeof(*e));
    if (!e) {
        if (!ctx->error)
            ctx->error = "内存分配失败";
        return;
    }
    e->name = name;
    e->fn = fn;
    e->ret = TYPE_VOID;
    e->params = NULL;
    e->pc = 0;
    e->next = ctx->fn_table;
    ctx->fn_table = e;
}

/* Frees every entry of the function table and leaves it empty. */
static void free_fns(CgCtx* ctx) {
    FnEntry* e = ctx->fn_table;
    while (e) {
        FnEntry* next = e->next;
        free(e->params);
        free(e);
        e = next;
    }
    ctx->fn_table = NULL;
}

// === Forward declarations ===
static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node);
static void codegen_stmt(CgCtx* ctx, AstNode* node);

// === Expression code generation ===

/* Handles the built-in print_i64 / print_f64 / print_bool calls by emitting a
 * call to printf.
 *
 * Returns false when the callee is not one of the built-ins (`*out` is left
 * untouched). Returns true otherwise, with `*out` set to the printf call, or
 * to NULL when the argument is missing or could not be generated. */
static bool codegen_print_call(CgCtx* ctx, AstNode* node, LLVMValueRef* out) {
    const char* callee = node->as.call.name;
    bool is_i64 = strcmp(callee, "print_i64") == 0;
    bool is_f64 = !is_i64 && strcmp(callee, "print_f64") == 0;
    bool is_bool = !is_i64 && !is_f64 && strcmp(callee, "print_bool") == 0;
    if (!is_i64 && !is_f64 && !is_bool)
        return false;

    *out = NULL;
    // Do not read args[0] of a call that carries no arguments.
    if (node->as.call.arg_count < 1)
        return true;

    LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
    if (!arg)
        return true;

    if (is_bool) {
        // Turn the bool into text: select between "true\n" and "false\n"
        // and pass the selected string as the printf format.
        LLVMValueRef one = LLVMConstInt(LLVMInt1TypeInContext(ctx->context), 1, false);
        LLVMValueRef c = LLVMBuildICmp(ctx->builder, LLVMIntEQ, arg, one, "bool_cmp");
        LLVMValueRef true_str = LLVMBuildGlobalStringPtr(ctx->builder, "true\n", "true_str");
        LLVMValueRef false_str = LLVMBuildGlobalStringPtr(ctx->builder, "false\n", "false_str");
        LLVMValueRef selected = LLVMBuildSelect(ctx->builder, c, true_str, false_str, "bool_sel");
        LLVMValueRef printf_args[] = { selected };
        *out = LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, printf_args, 1, "");
        return true;
    }

    LLVMValueRef fmt = is_i64
        ? LLVMBuildGlobalStringPtr(ctx->builder, "%lld\n", "fmt_i64")
        : LLVMBuildGlobalStringPtr(ctx->builder, "%f\n", "fmt_f64");
    LLVMValueRef printf_args[] = { fmt, arg };
    *out = LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, printf_args, 2, "");
    return true;
}

/* Emits a call to a user-defined function. Returns NULL when the callee is
 * unknown, an argument fails to generate, or memory runs out. */
static LLVMValueRef codegen_user_call(CgCtx* ctx, AstNode* node) {
    LLVMValueRef fn = find_fn(ctx, node->as.call.name);
    if (!fn)
        return NULL;

    // The argument array is sized to the call, so any number of arguments
    // is supported.
    size_t argc = node->as.call.arg_count;
    LLVMValueRef* args = NULL;
    if (argc > 0) {
        args = calloc(argc, sizeof(*args));
        if (!args) {
            if (!ctx->error)
                ctx->error = "内存分配失败";
            return NULL;
        }
    }

    for (size_t i = 0; i < argc; i++) {
        args[i] = codegen_expr(ctx, node->as.call.args[i]);
        if (!args[i]) {
            free(args);
            return NULL;
        }
    }

    LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
    LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
    // A void call must not be given a result name.
    const char* result_name =
        ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call";
    LLVMValueRef call =
        LLVMBuildCall2(ctx->builder, fn_ty, fn, args, (unsigned)argc, result_name);
    free(args);
    return call;
}

/* Emits IR for a binary expression. The integer/float instruction is chosen
 * from the LLVM type of the left operand, so comparisons of doubles use fcmp
 * whatever type the comparison node itself carries. */
static LLVMValueRef codegen_binary(CgCtx* ctx, AstNode* node) {
    LLVMValueRef l = codegen_expr(ctx, node->as.binary.left);
    LLVMValueRef r = codegen_expr(ctx, node->as.binary.right);
    if (!l || !r)
        return NULL;

    LLVMBuilderRef b = ctx->builder;
    bool is_float = is_float_value(l);

    switch (node->as.binary.op) {
    case OP_ADD:
        return is_float ? LLVMBuildFAdd(b, l, r, "fadd") : LLVMBuildAdd(b, l, r, "iadd");
    case OP_SUB:
        return is_float ? LLVMBuildFSub(b, l, r, "fsub") : LLVMBuildSub(b, l, r, "isub");
    case OP_MUL:
        return is_float ? LLVMBuildFMul(b, l, r, "fmul") : LLVMBuildMul(b, l, r, "imul");
    case OP_DIV:
        return is_float ? LLVMBuildFDiv(b, l, r, "fdiv") : LLVMBuildSDiv(b, l, r, "sdiv");
    case OP_MOD:
        return is_float ? LLVMBuildFRem(b, l, r, "frem") : LLVMBuildSRem(b, l, r, "srem");
    case OP_EQ:
        return is_float ? LLVMBuildFCmp(b, LLVMRealOEQ, l, r, "feq")
                        : LLVMBuildICmp(b, LLVMIntEQ, l, r, "ieq");
    case OP_NE:
        return is_float ? LLVMBuildFCmp(b, LLVMRealONE, l, r, "fne")
                        : LLVMBuildICmp(b, LLVMIntNE, l, r, "ine");
    case OP_LT:
        return is_float ? LLVMBuildFCmp(b, LLVMRealOLT, l, r, "flt")
                        : LLVMBuildICmp(b, LLVMIntSLT, l, r, "ilt");
    case OP_GT:
        return is_float ? LLVMBuildFCmp(b, LLVMRealOGT, l, r, "fgt")
                        : LLVMBuildICmp(b, LLVMIntSGT, l, r, "igt");
    case OP_LE:
        return is_float ? LLVMBuildFCmp(b, LLVMRealOLE, l, r, "fle")
                        : LLVMBuildICmp(b, LLVMIntSLE, l, r, "ile");
    case OP_GE:
        return is_float ? LLVMBuildFCmp(b, LLVMRealOGE, l, r, "fge")
                        : LLVMBuildICmp(b, LLVMIntSGE, l, r, "ige");
    case OP_AND:
        return LLVMBuildAnd(b, l, r, "and");
    case OP_OR:
        return LLVMBuildOr(b, l, r, "or");
    default:
        return NULL;
    }
}

/* Emits IR for an expression and returns its value, or NULL when the node is
 * NULL, of an unhandled kind, or refers to an unknown variable/function. */
static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) {
    if (!node)
        return NULL;

    switch (node->kind) {
    case AST_LITERAL_EXPR:
        return to_llvm_const(to_llvm_type(ctx, node->type.kind), node);

    case AST_IDENT_EXPR: {
        LLVMValueRef ptr = find_var(ctx, node->as.ident.name);
        if (!ptr)
            return NULL;
        return LLVMBuildLoad2(ctx->builder, to_llvm_type(ctx, node->type.kind), ptr, "load");
    }

    case AST_UNARY_EXPR: {
        LLVMValueRef operand = codegen_expr(ctx, node->as.unary.operand);
        if (!operand)
            return NULL;
        if (node->as.unary.op != OP_NEG)
            return LLVMBuildNot(ctx->builder, operand, "not");
        if (node->type.kind == TYPE_F64)
            return LLVMBuildFNeg(ctx->builder, operand, "fneg");
        return LLVMBuildNeg(ctx->builder, operand, "ineg");
    }

    case AST_BINARY_EXPR:
        return codegen_binary(ctx, node);

    case AST_CALL_EXPR: {
        // Built-in print functions delegate to printf ...
        LLVMValueRef printed = NULL;
        if (codegen_print_call(ctx, node, &printed))
            return printed;
        // ... everything else is a regular function call.
        return codegen_user_call(ctx, node);
    }

    default:
        return NULL;
    }
}

// === Statement code generation ===

/* Emits IR for a statement at the builder's current position. Failures are
 * not reported here; a statement whose expression cannot be generated is
 * simply not emitted (see codegen_module for the final verification). */
static void codegen_stmt(CgCtx* ctx, AstNode* node) {
    if (!node)
        return;

    switch (node->kind) {
    case AST_LET_STMT: {
        LLVMValueRef init_val = codegen_expr(ctx, node->as.let_stmt.init);
        if (!init_val)
            return;
        LLVMTypeRef slot_ty = to_llvm_type(ctx, node->as.let_stmt.init->type.kind);
        LLVMValueRef slot = LLVMBuildAlloca(ctx->builder, slot_ty, node->as.let_stmt.name);
        LLVMBuildStore(ctx->builder, init_val, slot);
        add_var(ctx, node->as.let_stmt.name, slot);
        break;
    }

    case AST_EXPR_STMT:
        codegen_expr(ctx, node->as.expr_stmt.expr);
        break;

    case AST_RETURN_STMT:
        if (node->as.return_stmt.expr) {
            LLVMValueRef val = codegen_expr(ctx, node->as.return_stmt.expr);
            if (val)
                LLVMBuildRet(ctx->builder, val);
        } else {
            LLVMBuildRetVoid(ctx->builder);
        }
        break;

    case AST_BLOCK:
        for (size_t i = 0; i < node->as.block.stmt_count; i++) {
            // Statements after a `return` are unreachable; emitting them
            // would place instructions after the block's terminator.
            if (LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder)))
                break;
            codegen_stmt(ctx, node->as.block.stmts[i]);
        }
        break;

    case AST_IF_STMT: {
        LLVMValueRef cond = codegen_expr(ctx, node->as.if_stmt.cond);
        if (!cond)
            return;

        LLVMBasicBlockRef cur_bb = LLVMGetInsertBlock(ctx->builder);
        LLVMValueRef cur_fn = LLVMGetBasicBlockParent(cur_bb);
        LLVMBasicBlockRef then_bb =
            LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "then");
        LLVMBasicBlockRef else_bb = NULL;
        if (node->as.if_stmt.else_block)
            else_bb = LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "else");
        LLVMBasicBlockRef merge_bb =
            LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "if_merge");

        // Without an else branch a false condition falls through to merge.
        LLVMBuildCondBr(ctx->builder, cond, then_bb, else_bb ? else_bb : merge_bb);

        LLVMPositionBuilderAtEnd(ctx->builder, then_bb);
        codegen_stmt(ctx, node->as.if_stmt.then_block);
        if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder)))
            LLVMBuildBr(ctx->builder, merge_bb);

        if (else_bb) {
            LLVMPositionBuilderAtEnd(ctx->builder, else_bb);
            codegen_stmt(ctx, node->as.if_stmt.else_block);
            if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder)))
                LLVMBuildBr(ctx->builder, merge_bb);
        }

        LLVMPositionBuilderAtEnd(ctx->builder, merge_bb);
        break;
    }

    case AST_WHILE_STMT: {
        LLVMBasicBlockRef cur_bb = LLVMGetInsertBlock(ctx->builder);
        LLVMValueRef cur_fn = LLVMGetBasicBlockParent(cur_bb);
        LLVMBasicBlockRef cond_bb =
            LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "while_cond");
        LLVMBasicBlockRef body_bb =
            LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "while_body");
        LLVMBasicBlockRef exit_bb =
            LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "while_exit");

        LLVMBuildBr(ctx->builder, cond_bb);

        LLVMPositionBuilderAtEnd(ctx->builder, cond_bb);
        LLVMValueRef cond = codegen_expr(ctx, node->as.while_stmt.cond);
        if (!cond)
            return;
        LLVMBuildCondBr(ctx->builder, cond, body_bb, exit_bb);

        LLVMPositionBuilderAtEnd(ctx->builder, body_bb);
        codegen_stmt(ctx, node->as.while_stmt.body);
        if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder)))
            LLVMBuildBr(ctx->builder, cond_bb);

        LLVMPositionBuilderAtEnd(ctx->builder, exit_bb);
        break;
    }

    default:
        break;
    }
}

// === Program-level code generation ===

/* Stores `msg` through `error_msg` when the caller supplied a slot. */
static void report_error(const char** error_msg, const char* msg) {
    if (error_msg)
        *error_msg = msg;
}

/* Common failure exit for codegen_module(): reports `msg`, frees the symbol
 * tables, then disposes the builder and the context (the same teardown the
 * verification-failure path has always used). Always returns NULL. */
static LLVMModuleRef abort_codegen(CgCtx* ctx, const char** error_msg, const char* msg) {
    report_error(error_msg, msg);
    free_vars(ctx);
    free_fns(ctx);
    if (ctx->builder)
        LLVMDisposeBuilder(ctx->builder);
    if (ctx->context)
        LLVMContextDispose(ctx->context);
    return NULL;
}

/* Generates an LLVM module named `name` from the program node `ast`.
 *
 * Pass 1 declares every function so calls can refer to functions defined
 * later; pass 2 emits the bodies. The finished module is then verified.
 *
 * Returns the module on success. It lives in a context created here, and
 * that context is not disposed by this function. On failure returns NULL and
 * stores a message through `error_msg` (when non-NULL): either a string
 * literal or, for verification failures, the message buffer produced by
 * LLVMVerifyModule. */
LLVMModuleRef codegen_module(AstNode* ast, const char* name, const char** error_msg) {
    CgCtx ctx = {0};

    ctx.context = LLVMContextCreate();
    if (!ctx.context) {
        report_error(error_msg, "无法创建 LLVM Context");
        return NULL;
    }
    ctx.module = LLVMModuleCreateWithNameInContext(name, ctx.context);
    ctx.builder = LLVMCreateBuilderInContext(ctx.context);

    // Declare the C library's printf (the built-in print functions need it).
    LLVMTypeRef printf_param_types[] = {
        LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0)
    };
    ctx.printf_ty = LLVMFunctionType(LLVMInt32TypeInContext(ctx.context),
                                     printf_param_types, 1, true);
    ctx.printf_fn = LLVMAddFunction(ctx.module, "printf", ctx.printf_ty);

    // Pass 1: declare every function.
    for (size_t i = 0; i < ast->as.program.fn_count; i++) {
        AstNode* fn = ast->as.program.functions[i];
        size_t param_count = fn->as.function.param_count;

        LLVMTypeRef* ptypes = NULL;
        if (param_count > 0) {
            ptypes = calloc(param_count, sizeof(*ptypes));
            if (!ptypes)
                return abort_codegen(&ctx, error_msg, "内存分配失败");
        }
        for (size_t j = 0; j < param_count; j++)
            ptypes[j] = to_llvm_type(&ctx, fn->as.function.params[j]->as.parameter.type);

        LLVMTypeRef fty = LLVMFunctionType(
            to_llvm_type(&ctx, fn->as.function.return_type),
            ptypes, (unsigned)param_count, false);
        LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty);
        free(ptypes);

        add_fn(&ctx, fn->as.function.name, lfn);
        if (ctx.error)
            return abort_codegen(&ctx, error_msg, ctx.error);
    }

    // Pass 2: emit the function bodies.
    for (size_t i = 0; i < ast->as.program.fn_count; i++) {
        AstNode* fn = ast->as.program.functions[i];
        LLVMValueRef lfn = find_fn(&ctx, fn->as.function.name);
        LLVMBasicBlockRef entry =
            LLVMAppendBasicBlockInContext(ctx.context, lfn, "entry");
        LLVMPositionBuilderAtEnd(ctx.builder, entry);

        // Each function gets its own scope: drop the previous function's locals.
        free_vars(&ctx);

        // Spill every parameter into a stack slot and register it as a variable.
        for (size_t j = 0; j < fn->as.function.param_count; j++) {
            AstNode* p = fn->as.function.params[j];
            LLVMValueRef incoming = LLVMGetParam(lfn, (unsigned)j);
            LLVMValueRef slot = LLVMBuildAlloca(
                ctx.builder, to_llvm_type(&ctx, p->as.parameter.type),
                p->as.parameter.name);
            LLVMBuildStore(ctx.builder, incoming, slot);
            add_var(&ctx, p->as.parameter.name, slot);
        }

        codegen_stmt(&ctx, fn->as.function.body);

        // Make sure the function ends in a terminator.
        if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx.builder))) {
            if (fn->as.function.return_type == TYPE_VOID) {
                LLVMBuildRetVoid(ctx.builder);
            } else {
                // LLVMConstNull yields the zero value for integer and
                // floating-point return types alike.
                LLVMTypeRef ret_ty = to_llvm_type(&ctx, fn->as.function.return_type);
                LLVMBuildRet(ctx.builder, LLVMConstNull(ret_ty));
            }
        }

        if (ctx.error)
            return abort_codegen(&ctx, error_msg, ctx.error);
    }

    // Verify the module (ReturnStatus action so the full message is returned).
    char* verify_err = NULL;
    if (LLVMVerifyModule(ctx.module, LLVMReturnStatusAction, &verify_err)) {
        return abort_codegen(&ctx, error_msg,
                             verify_err ? verify_err : "模块验证失败(错误消息为 NULL)");
    }
    // The verifier may hand back a message buffer even on success.
    if (verify_err)
        LLVMDisposeMessage(verify_err);

    free_vars(&ctx);
    free_fns(&ctx);
    LLVMDisposeBuilder(ctx.builder);
    return ctx.module;
}