851 lines
36 KiB
C
851 lines
36 KiB
C
#include "codegen.h"
|
||
#include <llvm-c/Analysis.h>
|
||
#include <llvm-c/Types.h>
|
||
#include <string.h>
|
||
#include <stdio.h>
|
||
|
||
// === Recursion depth limit ===

// Deepest nesting of AST_BLOCK statements that codegen_stmt will descend into.
#define MAX_CODEGEN_DEPTH 1000

// Current AST_BLOCK nesting depth; codegen_stmt increments it on block entry
// and decrements it on block exit.
static int codegen_depth = 0;
|
||
|
||
// === 内部状态 ===
|
||
typedef struct VarEntry {
|
||
const char* name;
|
||
LLVMValueRef alloca;
|
||
LLVMTypeRef alloca_type; // 分配的类型(GEP 需要)
|
||
struct VarEntry* next;
|
||
} VarEntry;
|
||
|
||
typedef struct FnEntry {
|
||
const char* name;
|
||
LLVMValueRef fn;
|
||
TypeKind ret;
|
||
TypeKind* params;
|
||
size_t pc;
|
||
struct FnEntry* next;
|
||
} FnEntry;
|
||
|
||
// 结构体类型映射
|
||
typedef struct StructTypeEntry {
|
||
const char* name;
|
||
LLVMTypeRef llvm_type;
|
||
size_t field_count;
|
||
struct StructTypeEntry* next;
|
||
} StructTypeEntry;
|
||
|
||
typedef struct {
|
||
Arena* arena; // 代码生成阶段分配器
|
||
LLVMContextRef context; // LLVM 19+ 需要显式 Context
|
||
LLVMModuleRef module;
|
||
LLVMBuilderRef builder;
|
||
VarEntry* var_table;
|
||
const char* error;
|
||
FnEntry* fn_table;
|
||
StructTypeEntry* struct_table;
|
||
// printf 运行时支持(内置 print 函数委托给 printf)
|
||
LLVMValueRef printf_fn;
|
||
LLVMTypeRef printf_ty;
|
||
// 字符串拼接运行时支持
|
||
LLVMValueRef malloc_fn;
|
||
LLVMValueRef free_fn; // auto-free 需要 free()
|
||
LLVMValueRef strlen_fn;
|
||
LLVMValueRef memcpy_fn;
|
||
// 自动内存管理: 追踪需要 free 的 str alloca
|
||
LLVMValueRef* cleanup_list;
|
||
size_t cleanup_count;
|
||
size_t cleanup_cap;
|
||
} CgCtx;
|
||
|
||
// === 类型映射(需要 Context)===
|
||
static LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) {
|
||
switch (kind) {
|
||
case TYPE_I32: return LLVMInt32TypeInContext(ctx->context);
|
||
case TYPE_I64: return LLVMInt64TypeInContext(ctx->context);
|
||
case TYPE_U64: return LLVMInt64TypeInContext(ctx->context);
|
||
case TYPE_F64: return LLVMDoubleTypeInContext(ctx->context);
|
||
case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context);
|
||
case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context);
|
||
case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
|
||
case TYPE_STRUCT:
|
||
case TYPE_ENUM: return LLVMInt64TypeInContext(ctx->context);
|
||
case TYPE_UNKNOWN:
|
||
case TYPE_ERROR:
|
||
default: return LLVMVoidTypeInContext(ctx->context);
|
||
}
|
||
}
|
||
|
||
// Builds the LLVM constant for a literal node, using `ty` as the constant's
// type. i32/i64 literals are created with the sign-extend flag set; u64, char
// and bool literals without it. Returns NULL for any other literal type.
static LLVMValueRef to_llvm_const(LLVMTypeRef ty, AstNode* lit) {
    const TypeKind lit_kind = lit->as.literal.lit_type;

    if (lit_kind == TYPE_F64) {
        return LLVMConstReal(ty, lit->as.literal.f64_val);
    }
    if (lit_kind == TYPE_BOOL) {
        return LLVMConstInt(ty, lit->as.literal.bool_val ? 1 : 0, false);
    }

    const bool is_signed_int = (lit_kind == TYPE_I32 || lit_kind == TYPE_I64);
    const bool is_unsigned_int = (lit_kind == TYPE_U64 || lit_kind == TYPE_CHAR);
    if (!is_signed_int && !is_unsigned_int) {
        return NULL;
    }

    unsigned long long bits = (unsigned long long)lit->as.literal.i64_val;
    return LLVMConstInt(ty, bits, is_signed_int);
}
|
||
|
||
// === 变量表 ===
|
||
static LLVMValueRef find_var(CgCtx* ctx, const char* name) {
|
||
for (VarEntry* e = ctx->var_table; e; e = e->next)
|
||
if (strcmp(e->name, name) == 0) return e->alloca;
|
||
return NULL;
|
||
}
|
||
|
||
// Pushes a new variable at the head of the variable table, so it shadows any
// older entry with the same name. Does nothing if the arena allocation fails.
static void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
    VarEntry* entry = arena_alloc(ctx->arena, sizeof(*entry));
    if (entry == NULL) {
        return;
    }
    entry->name = name;
    entry->alloca = alloca;
    entry->alloca_type = alloca_type;
    entry->next = ctx->var_table;
    ctx->var_table = entry;
}
|
||
|
||
// === 函数表 ===
|
||
static LLVMValueRef find_fn(CgCtx* ctx, const char* name) {
|
||
for (FnEntry* e = ctx->fn_table; e; e = e->next)
|
||
if (strcmp(e->name, name) == 0) return e->fn;
|
||
return NULL;
|
||
}
|
||
|
||
// Registers `fn` under `name` at the head of the function table. The
// signature fields of the entry are reset (TYPE_VOID / NULL / 0). Does
// nothing if the arena allocation fails.
static void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn) {
    FnEntry* entry = arena_alloc(ctx->arena, sizeof(*entry));
    if (entry == NULL) {
        return;
    }
    entry->name = name;
    entry->fn = fn;
    entry->ret = TYPE_VOID;
    entry->params = NULL;
    entry->pc = 0;
    entry->next = ctx->fn_table;
    ctx->fn_table = entry;
}
|
||
|
||
// === 结构体类型表 ===
|
||
static void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc) {
|
||
StructTypeEntry* e = arena_alloc(ctx->arena, sizeof(*e));
|
||
if (!e) return;
|
||
e->name = name; e->llvm_type = ty; e->field_count = fc;
|
||
e->next = ctx->struct_table;
|
||
ctx->struct_table = e;
|
||
}
|
||
|
||
// Returns the LLVM type registered for the struct called `name`, or NULL
// when the struct table has no such entry.
static LLVMTypeRef find_struct_type(CgCtx* ctx, const char* name) {
    StructTypeEntry* cur = ctx->struct_table;
    while (cur) {
        if (strcmp(cur->name, name) == 0) {
            return cur->llvm_type;
        }
        cur = cur->next;
    }
    return NULL;
}
|
||
|
||
// 将整数值强制转换到目标 LLVM 类型(sext/zext/trunc)
|
||
static LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val,
|
||
LLVMTypeRef from_ty, LLVMTypeRef to_ty) {
|
||
if (from_ty == to_ty) return val;
|
||
int from_w = LLVMGetIntTypeWidth(from_ty);
|
||
int to_w = LLVMGetIntTypeWidth(to_ty);
|
||
if (from_w < to_w)
|
||
return LLVMBuildSExt(ctx->builder, val, to_ty, "sext");
|
||
else
|
||
return LLVMBuildTrunc(ctx->builder, val, to_ty, "trunc");
|
||
}
|
||
|
||
// 从 TypeInfo 生成 LLVM 类型(支持数组、结构体等复合类型)
|
||
static LLVMTypeRef type_info_to_llvm(CgCtx* ctx, const TypeInfo* ti) {
|
||
switch (ti->kind) {
|
||
case TYPE_ARRAY: {
|
||
TypeInfo elem = { .kind = ti->element_type, .struct_name = ti->element_struct_name };
|
||
LLVMTypeRef elem_ty = type_info_to_llvm(ctx, &elem);
|
||
return LLVMArrayType(elem_ty, (unsigned)ti->array_size);
|
||
}
|
||
case TYPE_STRUCT:
|
||
if (ti->struct_name) {
|
||
LLVMTypeRef st = find_struct_type(ctx, ti->struct_name);
|
||
if (st) return st;
|
||
}
|
||
return LLVMVoidTypeInContext(ctx->context);
|
||
case TYPE_ENUM:
|
||
return LLVMInt64TypeInContext(ctx->context);
|
||
default:
|
||
return to_llvm_type(ctx, ti->kind);
|
||
}
|
||
}
|
||
|
||
// === 向前声明 ===
|
||
static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node);
|
||
static void codegen_stmt(CgCtx* ctx, AstNode* node);
|
||
|
||
// === 表达式代码生成 ===
|
||
static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) {
|
||
if (!node) return NULL;
|
||
|
||
switch (node->kind) {
|
||
case AST_LITERAL_EXPR:
|
||
if (node->type.kind == TYPE_STR) {
|
||
return LLVMBuildGlobalStringPtr(ctx->builder, node->as.literal.str_val, "str");
|
||
}
|
||
return to_llvm_const(to_llvm_type(ctx, node->type.kind), node);
|
||
|
||
case AST_IDENT_EXPR: {
|
||
LLVMValueRef ptr = find_var(ctx, node->as.ident.name);
|
||
if (!ptr) return NULL;
|
||
LLVMTypeRef load_ty = type_info_to_llvm(ctx, &node->type);
|
||
return LLVMBuildLoad2(ctx->builder, load_ty, ptr, "load");
|
||
}
|
||
|
||
case AST_UNARY_EXPR: {
|
||
LLVMValueRef operand = codegen_expr(ctx, node->as.unary.operand);
|
||
if (!operand) return NULL;
|
||
if (node->as.unary.op == OP_NEG) {
|
||
if (node->type.kind == TYPE_F64)
|
||
return LLVMBuildFNeg(ctx->builder, operand, "fneg");
|
||
else
|
||
return LLVMBuildNeg(ctx->builder, operand, "ineg");
|
||
} else {
|
||
return LLVMBuildNot(ctx->builder, operand, "not");
|
||
}
|
||
}
|
||
|
||
case AST_BINARY_EXPR: {
|
||
LLVMValueRef l = codegen_expr(ctx, node->as.binary.left);
|
||
LLVMValueRef r = codegen_expr(ctx, node->as.binary.right);
|
||
if (!l || !r) return NULL;
|
||
|
||
// 字符串拼接:alloc 栈缓冲区,strcpy + strcat
|
||
if (node->type.kind == TYPE_STR) {
|
||
// strlen(left)
|
||
LLVMValueRef len_l = LLVMBuildCall2(ctx->builder,
|
||
LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn,
|
||
(LLVMValueRef[]){l}, 1, "strlen_l");
|
||
// strlen(right)
|
||
LLVMValueRef len_r = LLVMBuildCall2(ctx->builder,
|
||
LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn,
|
||
(LLVMValueRef[]){r}, 1, "strlen_r");
|
||
// total = len_l + len_r + 1
|
||
LLVMValueRef total = LLVMBuildAdd(ctx->builder, len_l, len_r, "total");
|
||
total = LLVMBuildAdd(ctx->builder, total,
|
||
LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "total_1");
|
||
// char* buf = malloc(total)
|
||
LLVMValueRef buf = LLVMBuildCall2(ctx->builder,
|
||
LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn,
|
||
(LLVMValueRef[]){total}, 1, "str_buf");
|
||
// memcpy(buf, left, len_l)
|
||
LLVMBuildCall2(ctx->builder,
|
||
LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn,
|
||
(LLVMValueRef[]){buf, l, len_l}, 3, "");
|
||
// memcpy(buf + len_l, right, len_r + 1) -- includes null terminator
|
||
LLVMValueRef offset_ptr = LLVMBuildGEP2(ctx->builder,
|
||
LLVMInt8TypeInContext(ctx->context), buf,
|
||
(LLVMValueRef[]){len_l}, 1, "offset");
|
||
LLVMValueRef len_r1 = LLVMBuildAdd(ctx->builder, len_r,
|
||
LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "len_r1");
|
||
LLVMBuildCall2(ctx->builder,
|
||
LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn,
|
||
(LLVMValueRef[]){offset_ptr, r, len_r1}, 3, "");
|
||
return buf;
|
||
}
|
||
|
||
bool is_float = (node->type.kind == TYPE_F64);
|
||
|
||
switch (node->as.binary.op) {
|
||
case OP_ADD:
|
||
return is_float ? LLVMBuildFAdd(ctx->builder, l, r, "fadd")
|
||
: LLVMBuildAdd(ctx->builder, l, r, "iadd");
|
||
case OP_SUB:
|
||
return is_float ? LLVMBuildFSub(ctx->builder, l, r, "fsub")
|
||
: LLVMBuildSub(ctx->builder, l, r, "isub");
|
||
case OP_MUL:
|
||
return is_float ? LLVMBuildFMul(ctx->builder, l, r, "fmul")
|
||
: LLVMBuildMul(ctx->builder, l, r, "imul");
|
||
case OP_DIV:
|
||
return is_float ? LLVMBuildFDiv(ctx->builder, l, r, "fdiv")
|
||
: LLVMBuildSDiv(ctx->builder, l, r, "sdiv");
|
||
case OP_MOD:
|
||
return LLVMBuildSRem(ctx->builder, l, r, "srem");
|
||
case OP_EQ:
|
||
return is_float ? LLVMBuildFCmp(ctx->builder, LLVMRealOEQ, l, r, "feq")
|
||
: LLVMBuildICmp(ctx->builder, LLVMIntEQ, l, r, "ieq");
|
||
case OP_NE:
|
||
return is_float ? LLVMBuildFCmp(ctx->builder, LLVMRealONE, l, r, "fne")
|
||
: LLVMBuildICmp(ctx->builder, LLVMIntNE, l, r, "ine");
|
||
case OP_LT:
|
||
return is_float ? LLVMBuildFCmp(ctx->builder, LLVMRealOLT, l, r, "flt")
|
||
: LLVMBuildICmp(ctx->builder, LLVMIntSLT, l, r, "ilt");
|
||
case OP_GT:
|
||
return is_float ? LLVMBuildFCmp(ctx->builder, LLVMRealOGT, l, r, "fgt")
|
||
: LLVMBuildICmp(ctx->builder, LLVMIntSGT, l, r, "igt");
|
||
case OP_LE:
|
||
return is_float ? LLVMBuildFCmp(ctx->builder, LLVMRealOLE, l, r, "fle")
|
||
: LLVMBuildICmp(ctx->builder, LLVMIntSLE, l, r, "ile");
|
||
case OP_GE:
|
||
return is_float ? LLVMBuildFCmp(ctx->builder, LLVMRealOGE, l, r, "fge")
|
||
: LLVMBuildICmp(ctx->builder, LLVMIntSGE, l, r, "ige");
|
||
case OP_AND:
|
||
return LLVMBuildAnd(ctx->builder, l, r, "and");
|
||
case OP_OR:
|
||
return LLVMBuildOr(ctx->builder, l, r, "or");
|
||
default:
|
||
return NULL;
|
||
}
|
||
}
|
||
|
||
case AST_CALL_EXPR: {
|
||
// === 内置 print 函数:委托给 printf ===
|
||
if (strcmp(node->as.call.name, "print_i64") == 0) {
|
||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||
if (!arg) return NULL;
|
||
LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context);
|
||
arg = coerce_int(ctx, arg, LLVMTypeOf(arg), i64_ty);
|
||
LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%lld\n", "fmt_i64");
|
||
LLVMValueRef printf_args[] = { fmt, arg };
|
||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||
printf_args, 2, "");
|
||
}
|
||
if (strcmp(node->as.call.name, "print_f64") == 0) {
|
||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||
if (!arg) return NULL;
|
||
LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%f\n", "fmt_f64");
|
||
LLVMValueRef printf_args[] = { fmt, arg };
|
||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||
printf_args, 2, "");
|
||
}
|
||
if (strcmp(node->as.call.name, "print_bool") == 0) {
|
||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||
if (!arg) return NULL;
|
||
// 将 bool 转为字符串:通过 select 在 "true\n" 和 "false\n" 之间选择
|
||
LLVMValueRef c = LLVMBuildICmp(ctx->builder, LLVMIntEQ, arg,
|
||
LLVMConstInt(LLVMInt1TypeInContext(ctx->context), 1, false), "bool_cmp");
|
||
LLVMValueRef true_str = LLVMBuildGlobalStringPtr(ctx->builder, "true\n", "true_str");
|
||
LLVMValueRef false_str = LLVMBuildGlobalStringPtr(ctx->builder, "false\n", "false_str");
|
||
LLVMValueRef selected = LLVMBuildSelect(ctx->builder, c, true_str, false_str, "bool_sel");
|
||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||
(LLVMValueRef[]){selected}, 1, "");
|
||
}
|
||
if (strcmp(node->as.call.name, "print_str") == 0) {
|
||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||
if (!arg) return NULL;
|
||
LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%s\n", "fmt_str");
|
||
LLVMValueRef printf_args[] = { fmt, arg };
|
||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||
printf_args, 2, "");
|
||
}
|
||
|
||
// === 常规函数调用 ===
|
||
LLVMValueRef fn = find_fn(ctx, node->as.call.name);
|
||
if (!fn) return NULL;
|
||
LLVMValueRef args[16];
|
||
for (size_t i = 0; i < node->as.call.arg_count; i++) {
|
||
args[i] = codegen_expr(ctx, node->as.call.args[i]);
|
||
if (!args[i]) return NULL;
|
||
}
|
||
LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
|
||
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
|
||
return LLVMBuildCall2(ctx->builder, fn_ty, fn,
|
||
args, (unsigned)node->as.call.arg_count,
|
||
ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call");
|
||
}
|
||
|
||
// === 结构体字段访问: p.x ===
|
||
case AST_FIELD_ACCESS: {
|
||
// 对对象求值(返回的是 struct 值)
|
||
LLVMValueRef struct_val = codegen_expr(ctx, node->as.field_access.object);
|
||
if (!struct_val) return NULL;
|
||
|
||
int field_idx = node->as.field_access.field_index;
|
||
if (field_idx < 0) return NULL; // sema 应当已经设置
|
||
|
||
// 用 extractvalue 从结构体值中提取字段
|
||
return LLVMBuildExtractValue(ctx->builder, struct_val,
|
||
(unsigned)field_idx, node->as.field_access.field);
|
||
}
|
||
|
||
// === 结构体初始化: Point { x: 10, y: 20 } ===
|
||
case AST_STRUCT_INIT: {
|
||
const char* st_name = node->as.struct_init.type_name;
|
||
LLVMTypeRef struct_ty = find_struct_type(ctx, st_name);
|
||
if (!struct_ty) return NULL;
|
||
|
||
// alloca 分配结构体空间
|
||
LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, struct_ty, "struct_init");
|
||
|
||
// 获取结构体字段名列表(从 struct_table 或从 AST 中)
|
||
// 对每个 init 字段,找到它在结构体中的索引并 store
|
||
for (size_t i = 0; i < node->as.struct_init.field_count; i++) {
|
||
AstNode* fval = node->as.struct_init.field_values[i];
|
||
LLVMValueRef val = codegen_expr(ctx, fval);
|
||
if (!val) return NULL;
|
||
|
||
// 获取字段指针: GEP struct_ty, alloca, 0, i
|
||
LLVMValueRef indices[] = {
|
||
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
|
||
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), (unsigned long long)i, false)
|
||
};
|
||
LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, struct_ty, alloca,
|
||
indices, 2, "field_ptr");
|
||
LLVMBuildStore(ctx->builder, val, field_ptr);
|
||
}
|
||
|
||
// 加载整个结构体值
|
||
return LLVMBuildLoad2(ctx->builder, struct_ty, alloca, "struct_val");
|
||
}
|
||
|
||
case AST_ENUM_VARIANT:
|
||
return LLVMConstInt(LLVMInt64TypeInContext(ctx->context),
|
||
(unsigned long long)node->as.enum_variant.variant_index, true);
|
||
|
||
case AST_METHOD_CALL: {
|
||
const char* struct_name = node->as.method_call.receiver->type.struct_name;
|
||
char mangled[256];
|
||
snprintf(mangled, sizeof(mangled), "%s$%s", struct_name,
|
||
node->as.method_call.method_name);
|
||
LLVMValueRef fn = find_fn(ctx, mangled);
|
||
if (!fn) return NULL;
|
||
// 参数列表: [receiver, 用户参数...]
|
||
LLVMValueRef args[16];
|
||
args[0] = codegen_expr(ctx, node->as.method_call.receiver);
|
||
if (!args[0]) return NULL;
|
||
for (size_t i = 0; i < node->as.method_call.arg_count; i++) {
|
||
args[i + 1] = codegen_expr(ctx, node->as.method_call.args[i]);
|
||
if (!args[i + 1]) return NULL;
|
||
}
|
||
LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
|
||
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
|
||
return LLVMBuildCall2(ctx->builder, fn_ty, fn, args,
|
||
(unsigned)(node->as.method_call.arg_count + 1),
|
||
ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "method_call");
|
||
}
|
||
|
||
case AST_INDEX_EXPR: {
|
||
// 获取数组变量的指针
|
||
AstNode* arr_node = node->as.index_expr.array;
|
||
LLVMValueRef arr_ptr = NULL;
|
||
LLVMTypeRef arr_gp_type = NULL;
|
||
|
||
if (arr_node->kind == AST_IDENT_EXPR) {
|
||
arr_ptr = find_var(ctx, arr_node->as.ident.name);
|
||
// 从变量表获取数组类型用于 GEP
|
||
for (VarEntry* e = ctx->var_table; e; e = e->next) {
|
||
if (strcmp(e->name, arr_node->as.ident.name) == 0) {
|
||
arr_gp_type = e->alloca_type; break;
|
||
}
|
||
}
|
||
}
|
||
if (!arr_ptr || !arr_gp_type) return NULL;
|
||
|
||
// 生成索引值
|
||
LLVMValueRef idx_val = codegen_expr(ctx, node->as.index_expr.index);
|
||
if (!idx_val) return NULL;
|
||
|
||
// GEP 索引必须是 i32,但 L 使用 i64。截断。
|
||
LLVMValueRef idx_i32 = LLVMBuildTrunc(ctx->builder, idx_val,
|
||
LLVMInt32TypeInContext(ctx->context), "idx32");
|
||
|
||
LLVMValueRef indices[] = {
|
||
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
|
||
idx_i32
|
||
};
|
||
LLVMValueRef elem_ptr = LLVMBuildGEP2(ctx->builder, arr_gp_type, arr_ptr, indices, 2, "arr_elem");
|
||
|
||
LLVMTypeRef elem_load_ty;
|
||
if (node->type.kind == TYPE_STRUCT && node->type.struct_name) {
|
||
elem_load_ty = find_struct_type(ctx, node->type.struct_name);
|
||
if (!elem_load_ty) elem_load_ty = to_llvm_type(ctx, node->type.kind);
|
||
} else {
|
||
elem_load_ty = type_info_to_llvm(ctx, &node->type);
|
||
}
|
||
return LLVMBuildLoad2(ctx->builder, elem_load_ty, elem_ptr, "arr_load");
|
||
}
|
||
|
||
default:
|
||
return NULL;
|
||
}
|
||
}
|
||
|
||
// === 自动内存管理: 作用域退出时释放 str 堆分配 ===
|
||
static void cleanup_add(CgCtx* ctx, LLVMValueRef alloca) {
|
||
if (ctx->cleanup_count >= ctx->cleanup_cap) {
|
||
size_t new_cap = ctx->cleanup_cap ? ctx->cleanup_cap * 2 : 16;
|
||
LLVMValueRef* new_list = arena_alloc(ctx->arena, new_cap * sizeof(LLVMValueRef));
|
||
if (!new_list) return;
|
||
if (ctx->cleanup_list)
|
||
memcpy(new_list, ctx->cleanup_list, ctx->cleanup_count * sizeof(LLVMValueRef));
|
||
ctx->cleanup_list = new_list;
|
||
ctx->cleanup_cap = new_cap;
|
||
}
|
||
ctx->cleanup_list[ctx->cleanup_count++] = alloca;
|
||
}
|
||
|
||
// 释放从 mark 位置开始的所有 str 变量
|
||
static void cleanup_emit(CgCtx* ctx, size_t from_mark) {
|
||
for (size_t j = from_mark; j < ctx->cleanup_count; j++) {
|
||
LLVMValueRef ptr = ctx->cleanup_list[j];
|
||
LLVMValueRef val = LLVMBuildLoad2(ctx->builder,
|
||
LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0), ptr, "free_load");
|
||
LLVMBuildCall2(ctx->builder,
|
||
LLVMGlobalGetValueType(ctx->free_fn), ctx->free_fn,
|
||
(LLVMValueRef[]){val}, 1, "");
|
||
}
|
||
ctx->cleanup_count = from_mark;
|
||
}
|
||
|
||
// === 语句代码生成 ===
|
||
static void codegen_stmt(CgCtx* ctx, AstNode* node) {
|
||
if (!node) return;
|
||
|
||
switch (node->kind) {
|
||
case AST_LET_STMT: {
|
||
// 使用节点的完整类型信息来确定 LLVM 类型
|
||
// 如果 sema 未运行 (node->type.kind == TYPE_UNKNOWN),回退到 init 的类型
|
||
LLVMTypeRef var_type;
|
||
if (node->type.kind == TYPE_UNKNOWN) {
|
||
// 回退到旧行为:使用 init 表达式的类型
|
||
AstNode* init_node = node->as.let_stmt.init;
|
||
if (init_node->type.kind == TYPE_STRUCT && init_node->type.struct_name) {
|
||
var_type = find_struct_type(ctx, init_node->type.struct_name);
|
||
if (!var_type) var_type = to_llvm_type(ctx, init_node->type.kind);
|
||
} else {
|
||
var_type = to_llvm_type(ctx, init_node->type.kind);
|
||
}
|
||
} else {
|
||
var_type = type_info_to_llvm(ctx, &node->type);
|
||
}
|
||
if (!var_type) return;
|
||
|
||
LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder,
|
||
var_type, node->as.let_stmt.name);
|
||
|
||
// 尝试生成 init 值;数组类型可能 init 失败 (自引用占位符)
|
||
LLVMValueRef init_val = codegen_expr(ctx, node->as.let_stmt.init);
|
||
if (init_val) {
|
||
// 若 init LLVM 类型与 alloca 类型不同,强制转换(如 i64→i32)
|
||
LLVMTypeRef init_ty = LLVMTypeOf(init_val);
|
||
if (init_ty != var_type && LLVMGetTypeKind(init_ty) == LLVMIntegerTypeKind
|
||
&& LLVMGetTypeKind(var_type) == LLVMIntegerTypeKind) {
|
||
init_val = coerce_int(ctx, init_val, init_ty, var_type);
|
||
}
|
||
LLVMBuildStore(ctx->builder, init_val, alloca);
|
||
} else if (node->type.kind == TYPE_ARRAY) {
|
||
// 数组声明: init 失败是预期的 (自引用), 存储零初始化
|
||
LLVMValueRef zero_init = LLVMConstNull(var_type);
|
||
LLVMBuildStore(ctx->builder, zero_init, alloca);
|
||
} else {
|
||
return;
|
||
}
|
||
add_var(ctx, node->as.let_stmt.name, alloca, var_type);
|
||
|
||
// 自动内存管理: 只追踪 str 堆分配 (拼接/malloc)
|
||
// struct 是栈上值类型,不能 free();含 str 字段时 v0.5 扩展
|
||
if (node->as.let_stmt.init->type.kind == TYPE_STR) {
|
||
AstKind ik = node->as.let_stmt.init->kind;
|
||
if (ik == AST_BINARY_EXPR || ik == AST_CALL_EXPR) {
|
||
cleanup_add(ctx, alloca);
|
||
}
|
||
}
|
||
break;
|
||
}
|
||
|
||
case AST_ASSIGN_STMT: {
|
||
LLVMValueRef ptr = find_var(ctx, node->as.assign_stmt.name);
|
||
if (!ptr) return;
|
||
LLVMValueRef val = codegen_expr(ctx, node->as.assign_stmt.value);
|
||
if (!val) return;
|
||
LLVMBuildStore(ctx->builder, val, ptr);
|
||
break;
|
||
}
|
||
|
||
case AST_EXPR_STMT:
|
||
codegen_expr(ctx, node->as.expr_stmt.expr);
|
||
break;
|
||
|
||
case AST_RETURN_STMT: {
|
||
// 先计算返回值
|
||
LLVMValueRef ret_val = NULL;
|
||
bool has_val = node->as.return_stmt.expr != NULL;
|
||
if (has_val) {
|
||
ret_val = codegen_expr(ctx, node->as.return_stmt.expr);
|
||
if (!ret_val) return;
|
||
}
|
||
// 如果返回的是 str 类型的变量,从清理列表移除以防止 use-after-free
|
||
if (has_val && node->as.return_stmt.expr->type.kind == TYPE_STR &&
|
||
node->as.return_stmt.expr->kind == AST_IDENT_EXPR) {
|
||
LLVMValueRef alloca = find_var(ctx, node->as.return_stmt.expr->as.ident.name);
|
||
if (alloca) {
|
||
for (size_t i = 0; i < ctx->cleanup_count; i++) {
|
||
if (ctx->cleanup_list[i] == alloca) {
|
||
ctx->cleanup_list[i] = ctx->cleanup_list[ctx->cleanup_count - 1];
|
||
ctx->cleanup_count--;
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
// return 前释放当前作用域所有 str 堆分配
|
||
cleanup_emit(ctx, 0);
|
||
// 然后 emit ret
|
||
if (has_val) LLVMBuildRet(ctx->builder, ret_val);
|
||
else LLVMBuildRetVoid(ctx->builder);
|
||
break;
|
||
}
|
||
|
||
case AST_BLOCK: {
|
||
if (++codegen_depth > MAX_CODEGEN_DEPTH) { codegen_depth--; return; }
|
||
size_t block_mark = ctx->cleanup_count;
|
||
for (size_t i = 0; i < node->as.block.stmt_count; i++) {
|
||
codegen_stmt(ctx, node->as.block.stmts[i]);
|
||
}
|
||
cleanup_emit(ctx, block_mark); // 作用域退出: 释放块内 str 堆分配
|
||
codegen_depth--;
|
||
break;
|
||
}
|
||
|
||
case AST_IF_STMT: {
|
||
LLVMValueRef cond = codegen_expr(ctx, node->as.if_stmt.cond);
|
||
if (!cond) return;
|
||
LLVMBasicBlockRef cur_bb = LLVMGetInsertBlock(ctx->builder);
|
||
LLVMValueRef cur_fn = LLVMGetBasicBlockParent(cur_bb);
|
||
LLVMBasicBlockRef then_bb = LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "then");
|
||
LLVMBasicBlockRef else_bb = node->as.if_stmt.else_block
|
||
? LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "else") : NULL;
|
||
LLVMBasicBlockRef merge_bb = LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "if_merge");
|
||
|
||
if (else_bb)
|
||
LLVMBuildCondBr(ctx->builder, cond, then_bb, else_bb);
|
||
else
|
||
LLVMBuildCondBr(ctx->builder, cond, then_bb, merge_bb);
|
||
|
||
LLVMPositionBuilderAtEnd(ctx->builder, then_bb);
|
||
codegen_stmt(ctx, node->as.if_stmt.then_block);
|
||
if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder)))
|
||
LLVMBuildBr(ctx->builder, merge_bb);
|
||
|
||
if (else_bb) {
|
||
LLVMPositionBuilderAtEnd(ctx->builder, else_bb);
|
||
codegen_stmt(ctx, node->as.if_stmt.else_block);
|
||
if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder)))
|
||
LLVMBuildBr(ctx->builder, merge_bb);
|
||
}
|
||
|
||
LLVMPositionBuilderAtEnd(ctx->builder, merge_bb);
|
||
break;
|
||
}
|
||
|
||
case AST_ARRAY_ASSIGN_STMT: {
|
||
LLVMValueRef arr_ptr = find_var(ctx, node->as.array_assign.name);
|
||
if (!arr_ptr) return;
|
||
|
||
// 获取数组的 LLVM 类型(从变量表中)
|
||
VarEntry* ve = NULL;
|
||
for (VarEntry* e = ctx->var_table; e; e = e->next)
|
||
if (strcmp(e->name, node->as.array_assign.name) == 0) { ve = e; break; }
|
||
|
||
LLVMValueRef idx_val = codegen_expr(ctx, node->as.array_assign.index);
|
||
if (!idx_val) return;
|
||
|
||
LLVMValueRef val_val = codegen_expr(ctx, node->as.array_assign.value);
|
||
if (!val_val) return;
|
||
|
||
// i64 → i32 截断
|
||
LLVMValueRef idx_i32 = LLVMBuildTrunc(ctx->builder, idx_val,
|
||
LLVMInt32TypeInContext(ctx->context), "idx32");
|
||
|
||
LLVMValueRef indices[] = {
|
||
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
|
||
idx_i32
|
||
};
|
||
LLVMValueRef elem_ptr = LLVMBuildGEP2(ctx->builder, ve->alloca_type, arr_ptr, indices, 2, "arr_assign_elem");
|
||
|
||
LLVMBuildStore(ctx->builder, val_val, elem_ptr);
|
||
break;
|
||
}
|
||
|
||
case AST_WHILE_STMT: {
|
||
LLVMBasicBlockRef cur_bb = LLVMGetInsertBlock(ctx->builder);
|
||
LLVMValueRef cur_fn = LLVMGetBasicBlockParent(cur_bb);
|
||
LLVMBasicBlockRef cond_bb = LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "while_cond");
|
||
LLVMBasicBlockRef body_bb = LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "while_body");
|
||
LLVMBasicBlockRef exit_bb = LLVMAppendBasicBlockInContext(ctx->context, cur_fn, "while_exit");
|
||
|
||
LLVMBuildBr(ctx->builder, cond_bb);
|
||
LLVMPositionBuilderAtEnd(ctx->builder, cond_bb);
|
||
LLVMValueRef cond = codegen_expr(ctx, node->as.while_stmt.cond);
|
||
if (!cond) return;
|
||
LLVMBuildCondBr(ctx->builder, cond, body_bb, exit_bb);
|
||
|
||
LLVMPositionBuilderAtEnd(ctx->builder, body_bb);
|
||
codegen_stmt(ctx, node->as.while_stmt.body);
|
||
if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx->builder)))
|
||
LLVMBuildBr(ctx->builder, cond_bb);
|
||
|
||
LLVMPositionBuilderAtEnd(ctx->builder, exit_bb);
|
||
break;
|
||
}
|
||
|
||
default:
|
||
break;
|
||
}
|
||
}
|
||
|
||
// === 程序级代码生成 ===
|
||
LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
|
||
const char* name, const char** error_msg,
|
||
LLVMContextRef* out_context) {
|
||
CgCtx ctx = {0};
|
||
ctx.arena = codegen_arena;
|
||
ctx.context = LLVMContextCreate();
|
||
if (!ctx.context) {
|
||
*error_msg = "无法创建 LLVM Context";
|
||
*out_context = NULL;
|
||
return NULL;
|
||
}
|
||
ctx.module = LLVMModuleCreateWithNameInContext(name, ctx.context);
|
||
ctx.builder = LLVMCreateBuilderInContext(ctx.context);
|
||
|
||
// 声明 C 标准库 printf(内置 print 函数依赖它)
|
||
LLVMTypeRef printf_param_types[] = {
|
||
LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0)
|
||
};
|
||
ctx.printf_ty = LLVMFunctionType(
|
||
LLVMInt32TypeInContext(ctx.context), printf_param_types, 1, true);
|
||
ctx.printf_fn = LLVMAddFunction(ctx.module, "printf", ctx.printf_ty);
|
||
|
||
// 声明 malloc: void* malloc(size_t)
|
||
LLVMTypeRef malloc_args[] = { LLVMInt64TypeInContext(ctx.context) };
|
||
LLVMTypeRef malloc_ty = LLVMFunctionType(
|
||
LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0), malloc_args, 1, false);
|
||
ctx.malloc_fn = LLVMAddFunction(ctx.module, "malloc", malloc_ty);
|
||
|
||
// 声明 free: void free(void*)
|
||
LLVMTypeRef free_args[] = { LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0) };
|
||
LLVMTypeRef free_ty = LLVMFunctionType(LLVMVoidTypeInContext(ctx.context), free_args, 1, false);
|
||
ctx.free_fn = LLVMAddFunction(ctx.module, "free", free_ty);
|
||
|
||
// 声明 strlen: size_t strlen(const char*)
|
||
LLVMTypeRef strlen_args[] = { LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0) };
|
||
LLVMTypeRef strlen_ty = LLVMFunctionType(
|
||
LLVMInt64TypeInContext(ctx.context), strlen_args, 1, false);
|
||
ctx.strlen_fn = LLVMAddFunction(ctx.module, "strlen", strlen_ty);
|
||
|
||
// 声明 memcpy: void* memcpy(void*, const void*, size_t)
|
||
LLVMTypeRef memcpy_args[] = {
|
||
LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0),
|
||
LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0),
|
||
LLVMInt64TypeInContext(ctx.context),
|
||
};
|
||
LLVMTypeRef memcpy_ty = LLVMFunctionType(
|
||
LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0),
|
||
memcpy_args, 3, false);
|
||
ctx.memcpy_fn = LLVMAddFunction(ctx.module, "memcpy", memcpy_ty);
|
||
|
||
// 第零遍:先创建所有命名结构体(占位符,未设置 body)
|
||
for (size_t i = 0; i < ast->as.program.struct_count; i++) {
|
||
AstNode* sd = ast->as.program.structs[i];
|
||
LLVMTypeRef llvm_st = LLVMStructCreateNamed(ctx.context, sd->as.struct_decl.name);
|
||
add_struct_type(&ctx, sd->as.struct_decl.name, llvm_st,
|
||
sd->as.struct_decl.field_count);
|
||
}
|
||
// 然后设置所有结构体的 body(此时所有结构体类型已注册,可互相引用)
|
||
for (size_t i = 0; i < ast->as.program.struct_count; i++) {
|
||
AstNode* sd = ast->as.program.structs[i];
|
||
LLVMTypeRef llvm_st = find_struct_type(&ctx, sd->as.struct_decl.name);
|
||
LLVMTypeRef* elem_types = arena_alloc(ctx.arena,
|
||
sd->as.struct_decl.field_count * sizeof(LLVMTypeRef));
|
||
for (size_t j = 0; j < sd->as.struct_decl.field_count; j++) {
|
||
AstNode* field = sd->as.struct_decl.fields[j];
|
||
if (field->as.parameter.type == TYPE_STRUCT &&
|
||
field->as.parameter.struct_type_name) {
|
||
elem_types[j] = find_struct_type(&ctx,
|
||
field->as.parameter.struct_type_name);
|
||
} else {
|
||
elem_types[j] = to_llvm_type(&ctx, field->as.parameter.type);
|
||
}
|
||
}
|
||
LLVMStructSetBody(llvm_st, elem_types,
|
||
(unsigned)sd->as.struct_decl.field_count, false);
|
||
}
|
||
|
||
// 第一遍:声明所有 L 函数
|
||
for (size_t i = 0; i < ast->as.program.fn_count; i++) {
|
||
AstNode* fn = ast->as.program.functions[i];
|
||
LLVMTypeRef* ptypes = arena_alloc(ctx.arena,
|
||
fn->as.function.param_count * sizeof(LLVMTypeRef));
|
||
for (size_t j = 0; j < fn->as.function.param_count; j++) {
|
||
AstNode* param = fn->as.function.params[j];
|
||
if (param->as.parameter.type == TYPE_STRUCT &&
|
||
param->as.parameter.struct_type_name) {
|
||
ptypes[j] = find_struct_type(&ctx, param->as.parameter.struct_type_name);
|
||
} else {
|
||
ptypes[j] = to_llvm_type(&ctx, param->as.parameter.type);
|
||
}
|
||
}
|
||
LLVMTypeRef ret_ty;
|
||
if (fn->as.function.return_type == TYPE_STRUCT &&
|
||
fn->as.function.return_struct_type_name) {
|
||
ret_ty = find_struct_type(&ctx, fn->as.function.return_struct_type_name);
|
||
} else {
|
||
ret_ty = to_llvm_type(&ctx, fn->as.function.return_type);
|
||
}
|
||
LLVMTypeRef fty = LLVMFunctionType(ret_ty,
|
||
ptypes, (unsigned)fn->as.function.param_count, false);
|
||
LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty);
|
||
add_fn(&ctx, fn->as.function.name, lfn);
|
||
}
|
||
|
||
// 第二遍:生成函数体
|
||
for (size_t i = 0; i < ast->as.program.fn_count; i++) {
|
||
AstNode* fn = ast->as.program.functions[i];
|
||
LLVMValueRef lfn = find_fn(&ctx, fn->as.function.name);
|
||
LLVMBasicBlockRef entry = LLVMAppendBasicBlockInContext(ctx.context, lfn, "entry");
|
||
LLVMPositionBuilderAtEnd(ctx.builder, entry);
|
||
|
||
// 清空变量表(每个函数独立作用域)
|
||
ctx.var_table = NULL;
|
||
|
||
// 将参数注册为变量
|
||
for (size_t j = 0; j < fn->as.function.param_count; j++) {
|
||
LLVMValueRef param = LLVMGetParam(lfn, (unsigned)j);
|
||
AstNode* pnode = fn->as.function.params[j];
|
||
LLVMTypeRef param_ty;
|
||
if (pnode->as.parameter.type == TYPE_STRUCT &&
|
||
pnode->as.parameter.struct_type_name) {
|
||
param_ty = find_struct_type(&ctx, pnode->as.parameter.struct_type_name);
|
||
} else {
|
||
param_ty = to_llvm_type(&ctx, pnode->as.parameter.type);
|
||
}
|
||
LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder,
|
||
param_ty, pnode->as.parameter.name);
|
||
LLVMBuildStore(ctx.builder, param, alloca);
|
||
add_var(&ctx, pnode->as.parameter.name, alloca, param_ty);
|
||
}
|
||
|
||
codegen_stmt(&ctx, fn->as.function.body);
|
||
|
||
// 确保函数有终止指令(terminator)
|
||
if (!LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(ctx.builder))) {
|
||
// 函数结尾隐式 return: 先释放所有 str 堆分配
|
||
cleanup_emit(&ctx, 0);
|
||
if (fn->as.function.return_type == TYPE_VOID)
|
||
LLVMBuildRetVoid(ctx.builder);
|
||
else if (fn->as.function.return_type == TYPE_STRUCT &&
|
||
fn->as.function.return_struct_type_name) {
|
||
LLVMTypeRef st_ty = find_struct_type(&ctx, fn->as.function.return_struct_type_name);
|
||
LLVMBuildRet(ctx.builder, st_ty ? LLVMConstNull(st_ty) :
|
||
LLVMConstInt(to_llvm_type(&ctx, TYPE_I64), 0, false));
|
||
}
|
||
else
|
||
LLVMBuildRet(ctx.builder,
|
||
(fn->as.function.return_type == TYPE_F64
|
||
? LLVMConstReal(to_llvm_type(&ctx, TYPE_F64), 0.0)
|
||
: LLVMConstInt(to_llvm_type(&ctx, fn->as.function.return_type), 0, false)));
|
||
}
|
||
}
|
||
|
||
// 验证模块(使用 ReturnStatus 以获取完整错误消息)
|
||
char* verify_err = NULL;
|
||
if (LLVMVerifyModule(ctx.module, LLVMReturnStatusAction, &verify_err)) {
|
||
*error_msg = verify_err ? verify_err
|
||
: arena_strdup(ctx.arena, "LLVM 模块验证失败");
|
||
LLVMDisposeBuilder(ctx.builder);
|
||
*out_context = ctx.context;
|
||
return NULL;
|
||
}
|
||
|
||
LLVMDisposeBuilder(ctx.builder);
|
||
*out_context = ctx.context;
|
||
return ctx.module;
|
||
}
|