Files
l-language/src/codegen/codegen.c
T

942 lines
40 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "codegen.h"
#include <llvm-c/Analysis.h>
#include <llvm-c/Types.h>
#include <string.h>
#include <stdio.h>
// === Recursion depth limit ===
// codegen_depth counts how many AST_BLOCK statements codegen_stmt is currently
// nested inside; it is incremented on block entry and decremented on exit.
static int codegen_depth = 0;
// Blocks nested deeper than this are skipped by codegen_stmt instead of being
// generated.
#define MAX_CODEGEN_DEPTH 1000
// === Internal state ===
// Node of the singly linked variable table: maps a variable name to the
// alloca that holds it. New entries are pushed at the head (see add_var), so
// lookups find the most recently added entry with a given name first.
typedef struct VarEntry {
const char* name; // variable name; the pointer is stored, not copied
LLVMValueRef alloca; // stack slot created with LLVMBuildAlloca
LLVMTypeRef alloca_type; // allocated type (needed for GEP)
struct VarEntry* next; // next (older) entry, or NULL
} VarEntry;
// Node of the singly linked function table: maps a function name to its LLVM
// function value.
typedef struct FnEntry {
const char* name; // function name; the pointer is stored, not copied
LLVMValueRef fn; // LLVM function returned by LLVMAddFunction
// NOTE(review): add_fn initialises ret/params/pc to TYPE_VOID/NULL/0 and
// nothing in the visible code reads them — confirm whether they are used.
TypeKind ret;
TypeKind* params;
size_t pc;
struct FnEntry* next; // next (older) entry, or NULL
} FnEntry;
// Struct type mapping: node of the singly linked table that maps a struct
// name to its named LLVM struct type.
typedef struct StructTypeEntry {
const char* name; // struct name; the pointer is stored, not copied
LLVMTypeRef llvm_type; // type created with LLVMStructCreateNamed
size_t field_count; // number of declared fields
struct StructTypeEntry* next; // next (older) entry, or NULL
} StructTypeEntry;
// Code generation context: all state shared by the codegen_* functions while
// one module is being generated.
typedef struct {
Arena* arena; // allocator for the code generation phase
LLVMContextRef context; // LLVM 19+ requires an explicit Context
LLVMModuleRef module;
LLVMBuilderRef builder;
VarEntry* var_table; // variables of the function being generated (reset per function)
// NOTE(review): never assigned in the visible code apart from zero-init.
const char* error;
FnEntry* fn_table;
StructTypeEntry* struct_table;
// printf runtime support (the built-in print functions delegate to printf)
LLVMValueRef printf_fn;
LLVMTypeRef printf_ty;
// runtime support for string concatenation
LLVMValueRef malloc_fn;
LLVMValueRef free_fn; // auto-free needs free()
LLVMValueRef strlen_fn;
LLVMValueRef memcpy_fn;
// automatic memory management: tracks the str allocas whose contents must be freed
LLVMValueRef* cleanup_list;
size_t cleanup_count; // number of live entries in cleanup_list
size_t cleanup_cap; // allocated capacity of cleanup_list, in entries
} CgCtx;
// === Type mapping (needs the Context) ===
// Maps a scalar TypeKind to the LLVM type used to represent it.
//   - i32/i64/u64/f64/bool/char map to the matching LLVM primitive.
//   - str is an i8 pointer in address space 0.
//   - struct and enum both map to the tagged-union layout { i64 tag, i64 payload }.
//   - everything else (including TYPE_UNKNOWN and TYPE_ERROR) maps to void.
static LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) {
    LLVMContextRef llctx = ctx->context;
    LLVMTypeRef mapped;

    switch (kind) {
    case TYPE_I32:
        mapped = LLVMInt32TypeInContext(llctx);
        break;
    case TYPE_I64:
    case TYPE_U64:
        mapped = LLVMInt64TypeInContext(llctx);
        break;
    case TYPE_F64:
        mapped = LLVMDoubleTypeInContext(llctx);
        break;
    case TYPE_BOOL:
        mapped = LLVMInt1TypeInContext(llctx);
        break;
    case TYPE_CHAR:
        mapped = LLVMInt8TypeInContext(llctx);
        break;
    case TYPE_STR:
        mapped = LLVMPointerType(LLVMInt8TypeInContext(llctx), 0);
        break;
    case TYPE_STRUCT:
    case TYPE_ENUM: {
        // tagged union: { i64 tag, i64 payload }
        LLVMTypeRef members[2];
        members[0] = LLVMInt64TypeInContext(llctx);
        members[1] = LLVMInt64TypeInContext(llctx);
        mapped = LLVMStructTypeInContext(llctx, members, 2, false);
        break;
    }
    case TYPE_UNKNOWN:
    case TYPE_ERROR:
    default:
        mapped = LLVMVoidTypeInContext(llctx);
        break;
    }
    return mapped;
}
// Builds the LLVM constant for a literal AST node, using `ty` as the constant's
// type. Signed integer literals are sign-extended; u64, char and bool are
// zero-extended. Returns NULL for literal kinds that have no scalar constant
// form here (e.g. strings).
static LLVMValueRef to_llvm_const(LLVMTypeRef ty, AstNode* lit) {
    LLVMValueRef constant = NULL;

    switch (lit->as.literal.lit_type) {
    case TYPE_I32:
    case TYPE_I64:
        constant = LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, true);
        break;
    case TYPE_U64:
    case TYPE_CHAR:
        constant = LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, false);
        break;
    case TYPE_F64:
        constant = LLVMConstReal(ty, lit->as.literal.f64_val);
        break;
    case TYPE_BOOL:
        if (lit->as.literal.bool_val)
            constant = LLVMConstInt(ty, 1, false);
        else
            constant = LLVMConstInt(ty, 0, false);
        break;
    default:
        break;
    }
    return constant;
}
// === Variable table ===
// Returns the alloca registered for `name`, or NULL when no such variable is
// in the table. The newest entry with a matching name wins.
static LLVMValueRef find_var(CgCtx* ctx, const char* name) {
    VarEntry* cursor = ctx->var_table;
    while (cursor != NULL) {
        if (strcmp(cursor->name, name) == 0) {
            return cursor->alloca;
        }
        cursor = cursor->next;
    }
    return NULL;
}
// Pushes a new variable onto the head of the variable table. The name pointer
// is stored as-is. If the arena cannot supply memory the variable is silently
// not registered.
static void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
    VarEntry* entry = arena_alloc(ctx->arena, sizeof(VarEntry));
    if (entry == NULL) {
        return;
    }
    entry->name = name;
    entry->alloca = alloca;
    entry->alloca_type = alloca_type;
    entry->next = ctx->var_table;
    ctx->var_table = entry;
}
// === Function table ===
// Returns the LLVM function registered under `name`, or NULL when there is
// none. The newest entry with a matching name wins.
static LLVMValueRef find_fn(CgCtx* ctx, const char* name) {
    FnEntry* cursor = ctx->fn_table;
    while (cursor != NULL) {
        if (strcmp(cursor->name, name) == 0) {
            return cursor->fn;
        }
        cursor = cursor->next;
    }
    return NULL;
}
// Pushes a function onto the head of the function table. The signature
// fields (ret/params/pc) are filled with placeholder values. If the arena
// cannot supply memory the function is silently not registered.
static void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn) {
    FnEntry* entry = arena_alloc(ctx->arena, sizeof(FnEntry));
    if (entry == NULL) {
        return;
    }
    entry->name = name;
    entry->fn = fn;
    entry->ret = TYPE_VOID;
    entry->params = NULL;
    entry->pc = 0;
    entry->next = ctx->fn_table;
    ctx->fn_table = entry;
}
// === Struct type table ===
// Pushes a struct type (`ty`, with `fc` fields) onto the head of the struct
// table under `name`. If the arena cannot supply memory the type is silently
// not registered.
static void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc) {
    StructTypeEntry* entry = arena_alloc(ctx->arena, sizeof(StructTypeEntry));
    if (entry == NULL) {
        return;
    }
    entry->name = name;
    entry->llvm_type = ty;
    entry->field_count = fc;
    entry->next = ctx->struct_table;
    ctx->struct_table = entry;
}
// Returns the LLVM struct type registered under `name`, or NULL when the
// struct table has no such entry.
static LLVMTypeRef find_struct_type(CgCtx* ctx, const char* name) {
    StructTypeEntry* cursor = ctx->struct_table;
    while (cursor != NULL) {
        if (strcmp(cursor->name, name) == 0) {
            return cursor->llvm_type;
        }
        cursor = cursor->next;
    }
    return NULL;
}
// Coerces an integer value to the target LLVM integer type (sext/zext/trunc).
//
//   ctx      code generation context (its builder emits the cast)
//   val      value to convert
//   from_ty  current LLVM type of `val`
//   to_ty    desired LLVM type
//
// Returns `val` unchanged when no cast is needed or possible:
//   - the types are identical,
//   - either type is not an integer type (LLVMGetIntTypeWidth is only valid
//     on integer types, so callers that pass e.g. a double are left alone),
//   - both types have the same bit width.
// Narrowing truncates. Widening sign-extends, except for i1: booleans are
// 0/1 values, so they are zero-extended (sign-extending `true` would give -1).
static LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val,
                               LLVMTypeRef from_ty, LLVMTypeRef to_ty) {
    if (from_ty == to_ty) return val;
    if (LLVMGetTypeKind(from_ty) != LLVMIntegerTypeKind ||
        LLVMGetTypeKind(to_ty) != LLVMIntegerTypeKind) {
        return val;
    }

    unsigned from_w = LLVMGetIntTypeWidth(from_ty);
    unsigned to_w = LLVMGetIntTypeWidth(to_ty);
    if (from_w == to_w) return val; // trunc/ext between equal widths is invalid IR

    if (from_w > to_w)
        return LLVMBuildTrunc(ctx->builder, val, to_ty, "trunc");
    if (from_w == 1)
        return LLVMBuildZExt(ctx->builder, val, to_ty, "zext");
    return LLVMBuildSExt(ctx->builder, val, to_ty, "sext");
}
// Produces the LLVM type for a full TypeInfo (handles composite types such as
// arrays and structs, and defers to to_llvm_type for scalars).
//   - array:  [array_size x element], where the element is resolved recursively
//             from element_type / element_struct_name.
//   - struct: the registered named struct; void when the name is missing or
//             not registered.
//   - enum:   the tagged-union layout { i64 tag, i64 payload }.
static LLVMTypeRef type_info_to_llvm(CgCtx* ctx, const TypeInfo* ti) {
    if (ti->kind == TYPE_ARRAY) {
        TypeInfo element_info = {
            .kind = ti->element_type,
            .struct_name = ti->element_struct_name
        };
        LLVMTypeRef element_ty = type_info_to_llvm(ctx, &element_info);
        return LLVMArrayType(element_ty, (unsigned)ti->array_size);
    }

    if (ti->kind == TYPE_STRUCT) {
        LLVMTypeRef registered = NULL;
        if (ti->struct_name) {
            registered = find_struct_type(ctx, ti->struct_name);
        }
        if (registered) {
            return registered;
        }
        return LLVMVoidTypeInContext(ctx->context);
    }

    if (ti->kind == TYPE_ENUM) {
        LLVMTypeRef members[2];
        members[0] = LLVMInt64TypeInContext(ctx->context);
        members[1] = LLVMInt64TypeInContext(ctx->context);
        return LLVMStructTypeInContext(ctx->context, members, 2, false);
    }

    return to_llvm_type(ctx, ti->kind);
}
// === Forward declarations ===
// codegen_expr and codegen_stmt call each other (block and if expressions
// contain statements; statements contain expressions), so both are declared
// before either is defined.
static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node);
static void codegen_stmt(CgCtx* ctx, AstNode* node);
// === Expression code generation ===

// Capacity of the on-stack argument buffers used when emitting calls. For
// method calls the receiver occupies one of these slots.
enum { CG_MAX_CALL_ARGS = 16 };

// Emits `malloc(strlen(l) + strlen(r) + 1)` followed by two memcpy calls that
// copy `l` and then `r` (including its NUL terminator) into the new buffer.
// Returns the buffer pointer; the malloc result is not checked at run time.
static LLVMValueRef emit_str_concat(CgCtx* ctx, LLVMValueRef l, LLVMValueRef r) {
    LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context);
    LLVMTypeRef strlen_ty = LLVMGlobalGetValueType(ctx->strlen_fn);
    LLVMTypeRef memcpy_ty = LLVMGlobalGetValueType(ctx->memcpy_fn);

    LLVMValueRef len_l = LLVMBuildCall2(ctx->builder, strlen_ty, ctx->strlen_fn,
                                        (LLVMValueRef[]){l}, 1, "strlen_l");
    LLVMValueRef len_r = LLVMBuildCall2(ctx->builder, strlen_ty, ctx->strlen_fn,
                                        (LLVMValueRef[]){r}, 1, "strlen_r");
    // total = len_l + len_r + 1
    LLVMValueRef total = LLVMBuildAdd(ctx->builder, len_l, len_r, "total");
    total = LLVMBuildAdd(ctx->builder, total, LLVMConstInt(i64_ty, 1, false), "total_1");
    // char* buf = malloc(total)
    LLVMValueRef buf = LLVMBuildCall2(ctx->builder,
                                      LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn,
                                      (LLVMValueRef[]){total}, 1, "str_buf");
    // memcpy(buf, left, len_l)
    LLVMBuildCall2(ctx->builder, memcpy_ty, ctx->memcpy_fn,
                   (LLVMValueRef[]){buf, l, len_l}, 3, "");
    // memcpy(buf + len_l, right, len_r + 1) -- includes the NUL terminator
    LLVMValueRef offset_ptr = LLVMBuildGEP2(ctx->builder,
                                            LLVMInt8TypeInContext(ctx->context), buf,
                                            (LLVMValueRef[]){len_l}, 1, "offset");
    LLVMValueRef len_r1 = LLVMBuildAdd(ctx->builder, len_r,
                                       LLVMConstInt(i64_ty, 1, false), "len_r1");
    LLVMBuildCall2(ctx->builder, memcpy_ty, ctx->memcpy_fn,
                   (LLVMValueRef[]){offset_ptr, r, len_r1}, 3, "");
    return buf;
}

// Emits a direct call to `fn`. The result is named `value_name` unless the
// callee returns void (void results must be unnamed).
static LLVMValueRef emit_direct_call(CgCtx* ctx, LLVMValueRef fn, LLVMValueRef* args,
                                     size_t argc, const char* value_name) {
    LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
    bool returns_void = LLVMGetTypeKind(LLVMGetReturnType(fn_ty)) == LLVMVoidTypeKind;
    return LLVMBuildCall2(ctx->builder, fn_ty, fn, args, (unsigned)argc,
                          returns_void ? "" : value_name);
}

// Generates LLVM IR for an expression node and returns its value.
//
// Returns NULL when `node` is NULL, when the node kind is not an expression
// handled here, when a sub-expression fails, or when a name (variable,
// function, struct type) cannot be resolved. A void-typed block or if
// expression also yields NULL after emitting its statements.
static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) {
    if (!node) return NULL;
    switch (node->kind) {
    case AST_LITERAL_EXPR:
        if (node->type.kind == TYPE_STR) {
            return LLVMBuildGlobalStringPtr(ctx->builder, node->as.literal.str_val, "str");
        }
        return to_llvm_const(to_llvm_type(ctx, node->type.kind), node);

    case AST_IDENT_EXPR: {
        LLVMValueRef ptr = find_var(ctx, node->as.ident.name);
        if (!ptr) return NULL;
        LLVMTypeRef load_ty = type_info_to_llvm(ctx, &node->type);
        return LLVMBuildLoad2(ctx->builder, load_ty, ptr, "load");
    }

    case AST_UNARY_EXPR: {
        LLVMValueRef operand = codegen_expr(ctx, node->as.unary.operand);
        if (!operand) return NULL;
        if (node->as.unary.op != OP_NEG)
            return LLVMBuildNot(ctx->builder, operand, "not");
        if (node->type.kind == TYPE_F64)
            return LLVMBuildFNeg(ctx->builder, operand, "fneg");
        return LLVMBuildNeg(ctx->builder, operand, "ineg");
    }

    case AST_BINARY_EXPR: {
        LLVMValueRef l = codegen_expr(ctx, node->as.binary.left);
        LLVMValueRef r = codegen_expr(ctx, node->as.binary.right);
        if (!l || !r) return NULL;

        // A str-typed binary expression is a concatenation.
        if (node->type.kind == TYPE_STR)
            return emit_str_concat(ctx, l, r);

        // Decide float vs. integer from the operand as well as the result:
        // a comparison of two doubles has a non-f64 result type, and picking
        // icmp for it would produce invalid IR.
        bool is_float = node->type.kind == TYPE_F64 ||
                        LLVMGetTypeKind(LLVMTypeOf(l)) == LLVMDoubleTypeKind;

        switch (node->as.binary.op) {
        case OP_ADD:
            return is_float ? LLVMBuildFAdd(ctx->builder, l, r, "fadd")
                            : LLVMBuildAdd(ctx->builder, l, r, "iadd");
        case OP_SUB:
            return is_float ? LLVMBuildFSub(ctx->builder, l, r, "fsub")
                            : LLVMBuildSub(ctx->builder, l, r, "isub");
        case OP_MUL:
            return is_float ? LLVMBuildFMul(ctx->builder, l, r, "fmul")
                            : LLVMBuildMul(ctx->builder, l, r, "imul");
        case OP_DIV:
            return is_float ? LLVMBuildFDiv(ctx->builder, l, r, "fdiv")
                            : LLVMBuildSDiv(ctx->builder, l, r, "sdiv");
        case OP_MOD:
            return LLVMBuildSRem(ctx->builder, l, r, "srem");
        case OP_EQ:
        case OP_NE:
        case OP_LT:
        case OP_GT:
        case OP_LE:
        case OP_GE: {
            // Enum comparison: compare the tag fields.
            LLVMValueRef cl = l, cr = r;
            if (node->as.binary.left->type.kind == TYPE_ENUM)
                cl = LLVMBuildExtractValue(ctx->builder, l, 0, "enum_tag_l");
            if (node->as.binary.right->type.kind == TYPE_ENUM)
                cr = LLVMBuildExtractValue(ctx->builder, r, 0, "enum_tag_r");

            LLVMIntPredicate ipred;
            LLVMRealPredicate fpred;
            switch (node->as.binary.op) {
            case OP_EQ: ipred = LLVMIntEQ;  fpred = LLVMRealOEQ; break;
            case OP_NE: ipred = LLVMIntNE;  fpred = LLVMRealONE; break;
            case OP_LT: ipred = LLVMIntSLT; fpred = LLVMRealOLT; break;
            case OP_GT: ipred = LLVMIntSGT; fpred = LLVMRealOGT; break;
            case OP_LE: ipred = LLVMIntSLE; fpred = LLVMRealOLE; break;
            case OP_GE: ipred = LLVMIntSGE; fpred = LLVMRealOGE; break;
            default: return NULL;
            }
            if (is_float)
                return LLVMBuildFCmp(ctx->builder, fpred, cl, cr, "fcmp");
            return LLVMBuildICmp(ctx->builder, ipred, cl, cr, "icmp");
        }
        case OP_AND:
            return LLVMBuildAnd(ctx->builder, l, r, "and");
        case OP_OR:
            return LLVMBuildOr(ctx->builder, l, r, "or");
        default:
            return NULL;
        }
    }

    case AST_CALL_EXPR: {
        const char* callee = node->as.call.name;
        size_t argc = node->as.call.arg_count;

        // === Built-in print functions: delegate to printf ===
        bool is_print_i64 = strcmp(callee, "print_i64") == 0;
        bool is_print_f64 = strcmp(callee, "print_f64") == 0;
        bool is_print_bool = strcmp(callee, "print_bool") == 0;
        bool is_print_str = strcmp(callee, "print_str") == 0;
        if (is_print_i64 || is_print_f64 || is_print_bool || is_print_str) {
            // Every built-in reads args[0]; refuse calls that have none.
            if (argc < 1) return NULL;
            LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
            if (!arg) return NULL;

            if (is_print_bool) {
                // Pick between "true\n" and "false\n" and use it as the format.
                LLVMValueRef c = LLVMBuildICmp(ctx->builder, LLVMIntEQ, arg,
                    LLVMConstInt(LLVMInt1TypeInContext(ctx->context), 1, false), "bool_cmp");
                LLVMValueRef true_str = LLVMBuildGlobalStringPtr(ctx->builder, "true\n", "true_str");
                LLVMValueRef false_str = LLVMBuildGlobalStringPtr(ctx->builder, "false\n", "false_str");
                LLVMValueRef selected = LLVMBuildSelect(ctx->builder, c, true_str, false_str, "bool_sel");
                return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
                                      (LLVMValueRef[]){selected}, 1, "");
            }

            LLVMValueRef fmt;
            if (is_print_i64) {
                // Enum values print their tag field.
                if (node->as.call.args[0]->type.kind == TYPE_ENUM)
                    arg = LLVMBuildExtractValue(ctx->builder, arg, 0, "tag");
                // %lld expects a 64-bit integer; only integers can be widened.
                if (LLVMGetTypeKind(LLVMTypeOf(arg)) == LLVMIntegerTypeKind)
                    arg = coerce_int(ctx, arg, LLVMTypeOf(arg),
                                     LLVMInt64TypeInContext(ctx->context));
                fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%lld\n", "fmt_i64");
            } else if (is_print_f64) {
                fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%f\n", "fmt_f64");
            } else {
                fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%s\n", "fmt_str");
            }
            LLVMValueRef printf_args[] = { fmt, arg };
            return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
                                  printf_args, 2, "");
        }

        // === Regular function call ===
        LLVMValueRef fn = find_fn(ctx, callee);
        if (!fn) return NULL;
        // The argument buffer is fixed-size; reject calls that would overflow it.
        if (argc > CG_MAX_CALL_ARGS) return NULL;
        LLVMValueRef args[CG_MAX_CALL_ARGS];
        for (size_t i = 0; i < argc; i++) {
            args[i] = codegen_expr(ctx, node->as.call.args[i]);
            if (!args[i]) return NULL;
        }
        return emit_direct_call(ctx, fn, args, argc, "call");
    }

    // === Struct field access: p.x ===
    case AST_FIELD_ACCESS: {
        // Evaluating the object yields a struct value.
        LLVMValueRef struct_val = codegen_expr(ctx, node->as.field_access.object);
        if (!struct_val) return NULL;
        int field_idx = node->as.field_access.field_index;
        if (field_idx < 0) return NULL; // semantic analysis should have set it
        return LLVMBuildExtractValue(ctx->builder, struct_val,
                                     (unsigned)field_idx, node->as.field_access.field);
    }

    // === Struct initialisation: Point { x: 10, y: 20 } ===
    case AST_STRUCT_INIT: {
        LLVMTypeRef struct_ty = find_struct_type(ctx, node->as.struct_init.type_name);
        if (!struct_ty) return NULL;
        LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, struct_ty, "struct_init");
        // NOTE(review): the i-th initialiser is stored into the i-th struct
        // field, i.e. initialisers are assumed to be in declaration order.
        for (size_t i = 0; i < node->as.struct_init.field_count; i++) {
            LLVMValueRef val = codegen_expr(ctx, node->as.struct_init.field_values[i]);
            if (!val) return NULL;
            LLVMValueRef indices[] = {
                LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
                LLVMConstInt(LLVMInt32TypeInContext(ctx->context), (unsigned long long)i, false)
            };
            LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, struct_ty, alloca,
                                                   indices, 2, "field_ptr");
            LLVMBuildStore(ctx->builder, val, field_ptr);
        }
        // Load the whole struct back as a value.
        return LLVMBuildLoad2(ctx->builder, struct_ty, alloca, "struct_val");
    }

    case AST_ENUM_VARIANT: {
        // tagged union: { i64 tag, i64 payload }
        LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context);
        LLVMValueRef tag = LLVMConstInt(i64_ty,
            (unsigned long long)node->as.enum_variant.variant_index, true);
        LLVMValueRef const_payload = LLVMConstInt(i64_ty, 0, true);
        LLVMValueRef dyn_payload = NULL;

        if (node->as.enum_variant.payload) {
            LLVMValueRef pv = codegen_expr(ctx, node->as.enum_variant.payload);
            if (pv) {
                // The payload slot is an i64, so bring the value into that shape.
                LLVMTypeRef pv_ty = LLVMTypeOf(pv);
                switch (LLVMGetTypeKind(pv_ty)) {
                case LLVMIntegerTypeKind:
                    pv = coerce_int(ctx, pv, pv_ty, i64_ty);
                    break;
                case LLVMDoubleTypeKind:
                    pv = LLVMBuildBitCast(ctx->builder, pv, i64_ty, "payload_bits");
                    break;
                case LLVMPointerTypeKind:
                    pv = LLVMBuildPtrToInt(ctx->builder, pv, i64_ty, "payload_addr");
                    break;
                default:
                    return NULL; // aggregates do not fit in the i64 slot
                }
                // Only constants may appear inside a constant struct; runtime
                // values are inserted with an instruction below.
                if (LLVMIsConstant(pv)) const_payload = pv;
                else dyn_payload = pv;
            }
        }

        // Build the aggregate in this module's context (LLVMConstStruct would
        // use the global context and mix contexts).
        LLVMValueRef fields[] = { tag, const_payload };
        LLVMValueRef variant = LLVMConstStructInContext(ctx->context, fields, 2, false);
        if (dyn_payload)
            variant = LLVMBuildInsertValue(ctx->builder, variant, dyn_payload, 1, "enum_payload");
        return variant;
    }

    case AST_METHOD_CALL: {
        AstNode* receiver = node->as.method_call.receiver;
        if (!receiver || !receiver->type.struct_name) return NULL;
        size_t user_argc = node->as.method_call.arg_count;
        // One slot is taken by the receiver.
        if (user_argc >= CG_MAX_CALL_ARGS) return NULL;

        // Methods are registered under the mangled name "<Struct>$<method>".
        char mangled[256];
        int written = snprintf(mangled, sizeof(mangled), "%s$%s",
                               receiver->type.struct_name,
                               node->as.method_call.method_name);
        // A truncated name could resolve to the wrong function.
        if (written < 0 || (size_t)written >= sizeof(mangled)) return NULL;
        LLVMValueRef fn = find_fn(ctx, mangled);
        if (!fn) return NULL;

        // Argument list: [receiver, user arguments...]
        LLVMValueRef args[CG_MAX_CALL_ARGS];
        args[0] = codegen_expr(ctx, receiver);
        if (!args[0]) return NULL;
        for (size_t i = 0; i < user_argc; i++) {
            args[i + 1] = codegen_expr(ctx, node->as.method_call.args[i]);
            if (!args[i + 1]) return NULL;
        }
        return emit_direct_call(ctx, fn, args, user_argc + 1, "method_call");
    }

    case AST_INDEX_EXPR: {
        // Only indexing of a named array variable is supported: the alloca
        // and its array type are looked up in the variable table.
        AstNode* arr_node = node->as.index_expr.array;
        LLVMValueRef arr_ptr = NULL;
        LLVMTypeRef arr_gp_type = NULL;
        if (arr_node && arr_node->kind == AST_IDENT_EXPR) {
            for (VarEntry* e = ctx->var_table; e; e = e->next) {
                if (strcmp(e->name, arr_node->as.ident.name) == 0) {
                    arr_ptr = e->alloca;
                    arr_gp_type = e->alloca_type;
                    break;
                }
            }
        }
        if (!arr_ptr || !arr_gp_type) return NULL;

        LLVMValueRef idx_val = codegen_expr(ctx, node->as.index_expr.index);
        if (!idx_val) return NULL;
        if (LLVMGetTypeKind(LLVMTypeOf(idx_val)) != LLVMIntegerTypeKind) return NULL;
        // Normalise the index to i32; coerce_int is a no-op when it already
        // is i32 (an unconditional trunc would be invalid in that case).
        LLVMTypeRef i32_ty = LLVMInt32TypeInContext(ctx->context);
        LLVMValueRef idx_i32 = coerce_int(ctx, idx_val, LLVMTypeOf(idx_val), i32_ty);
        LLVMValueRef indices[] = { LLVMConstInt(i32_ty, 0, false), idx_i32 };
        LLVMValueRef elem_ptr = LLVMBuildGEP2(ctx->builder, arr_gp_type, arr_ptr,
                                              indices, 2, "arr_elem");

        LLVMTypeRef elem_load_ty;
        if (node->type.kind == TYPE_STRUCT && node->type.struct_name) {
            elem_load_ty = find_struct_type(ctx, node->type.struct_name);
            if (!elem_load_ty) elem_load_ty = to_llvm_type(ctx, node->type.kind);
        } else {
            elem_load_ty = type_info_to_llvm(ctx, &node->type);
        }
        return LLVMBuildLoad2(ctx->builder, elem_load_ty, elem_ptr, "arr_load");
    }

    // Block expression: { stmt*; expr } evaluates to its final expression.
    case AST_BLOCK: {
        LLVMValueRef result = NULL;
        size_t count = node->as.block.stmt_count;
        for (size_t i = 0; i < count; i++) {
            AstNode* stmt = node->as.block.stmts[i];
            bool is_last = (i + 1 == count);
            if (is_last && stmt->kind == AST_EXPR_STMT && node->type.kind != TYPE_VOID) {
                result = codegen_expr(ctx, stmt->as.expr_stmt.expr);
            } else {
                codegen_stmt(ctx, stmt);
            }
        }
        return result;
    }

    // If expression: if cond { a } else { b }
    case AST_IF_STMT: {
        if (node->type.kind == TYPE_VOID) { codegen_stmt(ctx, node); return NULL; }
        LLVMValueRef cond_val = codegen_expr(ctx, node->as.if_stmt.cond);
        if (!cond_val) return NULL;

        // Both branches store their value into one slot that is read back at
        // the merge point.
        LLVMTypeRef res_ty = type_info_to_llvm(ctx, &node->type);
        LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, res_ty, "if_res");
        LLVMValueRef func = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx->builder));
        LLVMBasicBlockRef then_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "then");
        LLVMBasicBlockRef else_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "else");
        LLVMBasicBlockRef merge_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "if_merge");
        LLVMBuildCondBr(ctx->builder, cond_val, then_bb, else_bb);

        LLVMPositionBuilderAtEnd(ctx->builder, then_bb);
        LLVMValueRef then_val = codegen_expr(ctx, node->as.if_stmt.then_block);
        // A branch that already ended (e.g. with `return`) must not get a
        // second terminator.
        if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder))) {
            if (then_val) LLVMBuildStore(ctx->builder, then_val, alloca);
            LLVMBuildBr(ctx->builder, merge_bb);
        }

        LLVMPositionBuilderAtEnd(ctx->builder, else_bb);
        LLVMValueRef else_val = codegen_expr(ctx, node->as.if_stmt.else_block);
        if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder))) {
            if (else_val) LLVMBuildStore(ctx->builder, else_val, alloca);
            LLVMBuildBr(ctx->builder, merge_bb);
        }

        LLVMPositionBuilderAtEnd(ctx->builder, merge_bb);
        return LLVMBuildLoad2(ctx->builder, res_ty, alloca, "if_val");
    }

    default:
        return NULL;
    }
}
// === Automatic memory management: free str heap allocations on scope exit ===
// Appends `alloca` (the stack slot of a str variable whose value must be
// freed) to the cleanup list. The list grows by doubling, starting at 16
// entries; storage comes from the arena, so old buffers are simply abandoned.
// If the arena cannot supply a larger buffer the entry is silently dropped.
static void cleanup_add(CgCtx* ctx, LLVMValueRef alloca) {
    if (ctx->cleanup_count >= ctx->cleanup_cap) {
        size_t grown_cap = 16;
        if (ctx->cleanup_cap != 0) {
            grown_cap = ctx->cleanup_cap * 2;
        }

        LLVMValueRef* grown = arena_alloc(ctx->arena, grown_cap * sizeof(LLVMValueRef));
        if (grown == NULL) {
            return;
        }
        if (ctx->cleanup_list != NULL) {
            memcpy(grown, ctx->cleanup_list, ctx->cleanup_count * sizeof(LLVMValueRef));
        }
        ctx->cleanup_list = grown;
        ctx->cleanup_cap = grown_cap;
    }

    ctx->cleanup_list[ctx->cleanup_count] = alloca;
    ctx->cleanup_count += 1;
}
// Frees every str variable recorded at or after position `from_mark`: for
// each entry, loads the string pointer from its slot and emits a call to
// free(). Afterwards the list is truncated back to `from_mark` entries.
static void cleanup_emit(CgCtx* ctx, size_t from_mark) {
    size_t pos = from_mark;
    while (pos < ctx->cleanup_count) {
        LLVMValueRef slot = ctx->cleanup_list[pos];
        LLVMTypeRef str_ty = LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
        LLVMValueRef heap_ptr = LLVMBuildLoad2(ctx->builder, str_ty, slot, "free_load");

        LLVMValueRef free_args[1];
        free_args[0] = heap_ptr;
        LLVMBuildCall2(ctx->builder,
                       LLVMGlobalGetValueType(ctx->free_fn), ctx->free_fn,
                       free_args, 1, "");
        pos++;
    }
    ctx->cleanup_count = from_mark;
}
// === 语句代码生成 ===
static void codegen_stmt(CgCtx* ctx, AstNode* node) {
if (!node) return;
switch (node->kind) {
case AST_LET_STMT: {
// 使用节点的完整类型信息来确定 LLVM 类型
// 如果 sema 未运行 (node->type.kind == TYPE_UNKNOWN),回退到 init 的类型
LLVMTypeRef var_type;
if (node->type.kind == TYPE_UNKNOWN) {
// 回退到旧行为:使用 init 表达式的类型
AstNode* init_node = node->as.let_stmt.init;
if (init_node->type.kind == TYPE_STRUCT && init_node->type.struct_name) {
var_type = find_struct_type(ctx, init_node->type.struct_name);
if (!var_type) var_type = to_llvm_type(ctx, init_node->type.kind);
} else {
var_type = to_llvm_type(ctx, init_node->type.kind);
}
} else {
var_type = type_info_to_llvm(ctx, &node->type);
}
if (!var_type) return;
LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder,
var_type, node->as.let_stmt.name);
// 尝试生成 init 值;数组类型可能 init 失败 (自引用占位符)
LLVMValueRef init_val = codegen_expr(ctx, node->as.let_stmt.init);
if (init_val) {
// 若 init LLVM 类型与 alloca 类型不同,强制转换(如 i64→i32)
LLVMTypeRef init_ty = LLVMTypeOf(init_val);
if (init_ty != var_type && LLVMGetTypeKind(init_ty) == LLVMIntegerTypeKind
&& LLVMGetTypeKind(var_type) == LLVMIntegerTypeKind) {
init_val = coerce_int(ctx, init_val, init_ty, var_type);
}
LLVMBuildStore(ctx->builder, init_val, alloca);
} else if (node->type.kind == TYPE_ARRAY) {
// 数组声明: init 失败是预期的 (自引用), 存储零初始化
LLVMValueRef zero_init = LLVMConstNull(var_type);
LLVMBuildStore(ctx->builder, zero_init, alloca);
} else {
return;
}
add_var(ctx, node->as.let_stmt.name, alloca, var_type);
// 自动内存管理: 只追踪 str 堆分配 (拼接/malloc)
// struct 是栈上值类型,不能 free();含 str 字段时 v0.5 扩展
if (node->as.let_stmt.init->type.kind == TYPE_STR) {
AstKind ik = node->as.let_stmt.init->kind;
if (ik == AST_BINARY_EXPR || ik == AST_CALL_EXPR) {
cleanup_add(ctx, alloca);
}
}
break;
}
case AST_ASSIGN_STMT: {
LLVMValueRef ptr = find_var(ctx, node->as.assign_stmt.name);
if (!ptr) return;
LLVMValueRef val = codegen_expr(ctx, node->as.assign_stmt.value);
if (!val) return;
LLVMBuildStore(ctx->builder, val, ptr);
break;
}
case AST_EXPR_STMT:
codegen_expr(ctx, node->as.expr_stmt.expr);
break;
case AST_RETURN_STMT: {
// 先计算返回值
LLVMValueRef ret_val = NULL;
bool has_val = node->as.return_stmt.expr != NULL;
if (has_val) {
ret_val = codegen_expr(ctx, node->as.return_stmt.expr);
if (!ret_val) return;
}
// 如果返回的是 str 类型的变量,从清理列表移除以防止 use-after-free
if (has_val && node->as.return_stmt.expr->type.kind == TYPE_STR &&
node->as.return_stmt.expr->kind == AST_IDENT_EXPR) {
LLVMValueRef alloca = find_var(ctx, node->as.return_stmt.expr->as.ident.name);
if (alloca) {
for (size_t i = 0; i < ctx->cleanup_count; i++) {
if (ctx->cleanup_list[i] == alloca) {
ctx->cleanup_list[i] = ctx->cleanup_list[ctx->cleanup_count - 1];
ctx->cleanup_count--;
break;
}
}
}
}
// return 前释放当前作用域所有 str 堆分配
cleanup_emit(ctx, 0);
// 然后 emit ret
if (has_val) LLVMBuildRet(ctx->builder, ret_val);
else LLVMBuildRetVoid(ctx->builder);
break;
}
case AST_BLOCK: {
if (++codegen_depth > MAX_CODEGEN_DEPTH) { codegen_depth--; return; }
size_t block_mark = ctx->cleanup_count;
for (size_t i = 0; i < node->as.block.stmt_count; i++) {
codegen_stmt(ctx, node->as.block.stmts[i]);
}
cleanup_emit(ctx, block_mark); // 作用域退出: 释放块内 str 堆分配
codegen_depth--;
break;
}
case AST_IF_STMT: {
LLVMValueRef cond = codegen_expr(ctx, node->as.if_stmt.cond);
if (!cond) return;
LLVMBasicBlockRef cur_bb = LLVMGetInsertBlock(ctx->builder);
LLVMValueRef cur_fn = LLVMGetBasicBlockParent(cur_bb);
LLVMBasicBlockRef then_bb = LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "then");
LLVMBasicBlockRef else_bb = node->as.if_stmt.else_block
? LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "else") : NULL;
LLVMBasicBlockRef merge_bb = LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "if_merge");
if (else_bb)
LLVMBuildCondBr(ctx->builder, cond, then_bb, else_bb);
else
LLVMBuildCondBr(ctx->builder, cond, then_bb, merge_bb);
LLVMPositionBuilderAtEnd(ctx->builder, then_bb);
codegen_stmt(ctx, node->as.if_stmt.then_block);
if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder)))
LLVMBuildBr(ctx->builder, merge_bb);
if (else_bb) {
LLVMPositionBuilderAtEnd(ctx->builder, else_bb);
codegen_stmt(ctx, node->as.if_stmt.else_block);
if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder)))
LLVMBuildBr(ctx->builder, merge_bb);
}
LLVMPositionBuilderAtEnd(ctx->builder, merge_bb);
break;
}
case AST_ARRAY_ASSIGN_STMT: {
LLVMValueRef arr_ptr = find_var(ctx, node->as.array_assign.name);
if (!arr_ptr) return;
// 获取数组的 LLVM 类型(从变量表中)
VarEntry* ve = NULL;
for (VarEntry* e = ctx->var_table; e; e = e->next)
if (strcmp(e->name, node->as.array_assign.name) == 0) { ve = e; break; }
LLVMValueRef idx_val = codegen_expr(ctx, node->as.array_assign.index);
if (!idx_val) return;
LLVMValueRef val_val = codegen_expr(ctx, node->as.array_assign.value);
if (!val_val) return;
// i64 → i32 截断
LLVMValueRef idx_i32 = LLVMBuildTrunc(ctx->builder, idx_val,
LLVMInt32TypeInContext(ctx->context), "idx32");
LLVMValueRef indices[] = {
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
idx_i32
};
LLVMValueRef elem_ptr = LLVMBuildGEP2(ctx->builder, ve->alloca_type, arr_ptr, indices, 2, "arr_assign_elem");
LLVMBuildStore(ctx->builder, val_val, elem_ptr);
break;
}
case AST_WHILE_STMT: {
LLVMBasicBlockRef cur_bb = LLVMGetInsertBlock(ctx->builder);
LLVMValueRef cur_fn = LLVMGetBasicBlockParent(cur_bb);
LLVMBasicBlockRef cond_bb = LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "while_cond");
LLVMBasicBlockRef body_bb = LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "while_body");
LLVMBasicBlockRef exit_bb = LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "while_exit");
LLVMBuildBr(ctx->builder, cond_bb);
LLVMPositionBuilderAtEnd(ctx->builder, cond_bb);
LLVMValueRef cond = codegen_expr(ctx, node->as.while_stmt.cond);
if (!cond) return;
LLVMBuildCondBr(ctx->builder, cond, body_bb, exit_bb);
LLVMPositionBuilderAtEnd(ctx->builder, body_bb);
codegen_stmt(ctx, node->as.while_stmt.body);
if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder)))
LLVMBuildBr(ctx->builder, cond_bb);
LLVMPositionBuilderAtEnd(ctx->builder, exit_bb);
break;
}
default:
break;
}
}
// === 程序级代码生成 ===
LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
const char* name, const char** error_msg,
LLVMContextRef* out_context) {
CgCtx ctx = {0};
ctx.arena = codegen_arena;
ctx.context = LLVMContextCreate();
if (!ctx.context) {
*error_msg = "无法创建 LLVM Context";
*out_context = NULL;
return NULL;
}
ctx.module = LLVMModuleCreateWithNameInContext(name, ctx.context);
ctx.builder = LLVMCreateBuilderInContext(ctx.context);
// 声明 C 标准库 printf(内置 print 函数依赖它)
LLVMTypeRef printf_param_types[] = {
LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0)
};
ctx.printf_ty = LLVMFunctionType(
LLVMInt32TypeInContext(ctx.context), printf_param_types, 1, true);
ctx.printf_fn = LLVMAddFunction(ctx.module, "printf", ctx.printf_ty);
// 声明 malloc: void* malloc(size_t)
LLVMTypeRef malloc_args[] = { LLVMInt64TypeInContext(ctx.context) };
LLVMTypeRef malloc_ty = LLVMFunctionType(
LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0), malloc_args, 1, false);
ctx.malloc_fn = LLVMAddFunction(ctx.module, "malloc", malloc_ty);
// 声明 free: void free(void*)
LLVMTypeRef free_args[] = { LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0) };
LLVMTypeRef free_ty = LLVMFunctionType(LLVMVoidTypeInContext(ctx.context), free_args, 1, false);
ctx.free_fn = LLVMAddFunction(ctx.module, "free", free_ty);
// 声明 strlen: size_t strlen(const char*)
LLVMTypeRef strlen_args[] = { LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0) };
LLVMTypeRef strlen_ty = LLVMFunctionType(
LLVMInt64TypeInContext(ctx.context), strlen_args, 1, false);
ctx.strlen_fn = LLVMAddFunction(ctx.module, "strlen", strlen_ty);
// 声明 memcpy: void* memcpy(void*, const void*, size_t)
LLVMTypeRef memcpy_args[] = {
LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0),
LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0),
LLVMInt64TypeInContext(ctx.context),
};
LLVMTypeRef memcpy_ty = LLVMFunctionType(
LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0),
memcpy_args, 3, false);
ctx.memcpy_fn = LLVMAddFunction(ctx.module, "memcpy", memcpy_ty);
// __chkstk 桩:LLVM 在生成大栈帧代码时会引用此符号(MinGW x64: __chkstk
{
LLVMTypeRef chkstk_ty = LLVMFunctionType(LLVMVoidTypeInContext(ctx.context), NULL, 0, false);
LLVMValueRef chkstk_fn = LLVMAddFunction(ctx.module, "__chkstk", chkstk_ty);
LLVMBasicBlockRef chk_bb = LLVMAppendBasicBlockInContext(ctx.context, chkstk_fn, "entry");
LLVMPositionBuilderAtEnd(ctx.builder, chk_bb);
LLVMBuildRetVoid(ctx.builder);
}
// 第零遍:先创建所有命名结构体(占位符,未设置 body)
for (size_t i = 0; i < ast->as.program.struct_count; i++) {
AstNode* sd = ast->as.program.structs[i];
LLVMTypeRef llvm_st = LLVMStructCreateNamed(ctx.context, sd->as.struct_decl.name);
add_struct_type(&ctx, sd->as.struct_decl.name, llvm_st,
sd->as.struct_decl.field_count);
}
// 然后设置所有结构体的 body(此时所有结构体类型已注册,可互相引用)
for (size_t i = 0; i < ast->as.program.struct_count; i++) {
AstNode* sd = ast->as.program.structs[i];
LLVMTypeRef llvm_st = find_struct_type(&ctx, sd->as.struct_decl.name);
LLVMTypeRef* elem_types = arena_alloc(ctx.arena,
sd->as.struct_decl.field_count * sizeof(LLVMTypeRef));
for (size_t j = 0; j < sd->as.struct_decl.field_count; j++) {
AstNode* field = sd->as.struct_decl.fields[j];
if (field->as.parameter.type == TYPE_STRUCT &&
field->as.parameter.struct_type_name) {
elem_types[j] = find_struct_type(&ctx,
field->as.parameter.struct_type_name);
} else {
elem_types[j] = to_llvm_type(&ctx, field->as.parameter.type);
}
}
LLVMStructSetBody(llvm_st, elem_types,
(unsigned)sd->as.struct_decl.field_count, false);
}
// 第一遍:声明所有 L 函数
for (size_t i = 0; i < ast->as.program.fn_count; i++) {
AstNode* fn = ast->as.program.functions[i];
LLVMTypeRef* ptypes = arena_alloc(ctx.arena,
fn->as.function.param_count * sizeof(LLVMTypeRef));
for (size_t j = 0; j < fn->as.function.param_count; j++) {
AstNode* param = fn->as.function.params[j];
if (param->as.parameter.type == TYPE_STRUCT &&
param->as.parameter.struct_type_name) {
ptypes[j] = find_struct_type(&ctx, param->as.parameter.struct_type_name);
} else {
ptypes[j] = to_llvm_type(&ctx, param->as.parameter.type);
}
}
LLVMTypeRef ret_ty;
if (fn->as.function.return_type == TYPE_STRUCT &&
fn->as.function.return_struct_type_name) {
ret_ty = find_struct_type(&ctx, fn->as.function.return_struct_type_name);
} else {
ret_ty = to_llvm_type(&ctx, fn->as.function.return_type);
}
LLVMTypeRef fty = LLVMFunctionType(ret_ty,
ptypes, (unsigned)fn->as.function.param_count, false);
LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty);
add_fn(&ctx, fn->as.function.name, lfn);
}
// 第二遍:生成函数体
for (size_t i = 0; i < ast->as.program.fn_count; i++) {
AstNode* fn = ast->as.program.functions[i];
LLVMValueRef lfn = find_fn(&ctx, fn->as.function.name);
LLVMBasicBlockRef entry = LLVMAppendBasicBlockInContext(ctx.context, lfn, "entry");
LLVMPositionBuilderAtEnd(ctx.builder, entry);
// 清空变量表(每个函数独立作用域)
ctx.var_table = NULL;
// 将参数注册为变量
for (size_t j = 0; j < fn->as.function.param_count; j++) {
LLVMValueRef param = LLVMGetParam(lfn, (unsigned)j);
AstNode* pnode = fn->as.function.params[j];
LLVMTypeRef param_ty;
if (pnode->as.parameter.type == TYPE_STRUCT &&
pnode->as.parameter.struct_type_name) {
param_ty = find_struct_type(&ctx, pnode->as.parameter.struct_type_name);
} else {
param_ty = to_llvm_type(&ctx, pnode->as.parameter.type);
}
LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder,
param_ty, pnode->as.parameter.name);
LLVMBuildStore(ctx.builder, param, alloca);
add_var(&ctx, pnode->as.parameter.name, alloca, param_ty);
}
codegen_stmt(&ctx, fn->as.function.body);
// 确保函数有终止指令(terminator)
if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx.builder))) {
// 函数结尾隐式 return: 先释放所有 str 堆分配
cleanup_emit(&ctx, 0);
if (fn->as.function.return_type == TYPE_VOID)
LLVMBuildRetVoid(ctx.builder);
else if (fn->as.function.return_type == TYPE_STRUCT &&
fn->as.function.return_struct_type_name) {
LLVMTypeRef st_ty = find_struct_type(&ctx, fn->as.function.return_struct_type_name);
LLVMBuildRet(ctx.builder, st_ty ? LLVMConstNull(st_ty) :
LLVMConstInt(to_llvm_type(&ctx, TYPE_I64), 0, false));
}
else
LLVMBuildRet(ctx.builder,
(fn->as.function.return_type == TYPE_F64
? LLVMConstReal(to_llvm_type(&ctx, TYPE_F64), 0.0)
: LLVMConstInt(to_llvm_type(&ctx, fn->as.function.return_type), 0, false)));
}
}
// 验证模块(使用 ReturnStatus 以获取完整错误消息)
char* verify_err = NULL;
if (LLVMVerifyModule(ctx.module, LLVMReturnStatusAction, &verify_err)) {
*error_msg = verify_err ? verify_err
: arena_strdup(ctx.arena, "LLVM 模块验证失败");
LLVMDisposeBuilder(ctx.builder);
*out_context = ctx.context;
return NULL;
}
LLVMDisposeBuilder(ctx.builder);
*out_context = ctx.context;
return ctx.module;
}