diff --git a/src/codegen/cg_expr.c b/src/codegen/cg_expr.c index 9ea47fe..4a4204f 100644 --- a/src/codegen/cg_expr.c +++ b/src/codegen/cg_expr.c @@ -1,5 +1,7 @@ #include "codegen_internal.h" +#include "visit.h" +// === 类型映射 (保持不变) === LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) { switch (kind) { case TYPE_I32: return LLVMInt32TypeInContext(ctx->context); @@ -11,14 +13,14 @@ LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) { case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0); case TYPE_STRUCT: case TYPE_ENUM: { - // tagged union: { i64 tag, i64 payload } LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context), LLVMInt64TypeInContext(ctx->context) }; return LLVMStructTypeInContext(ctx->context, fields, 2, false); } + case TYPE_VOID: case TYPE_UNKNOWN: case TYPE_ERROR: - default: return LLVMVoidTypeInContext(ctx->context); + default: return LLVMVoidTypeInContext(ctx->context); } } @@ -34,8 +36,9 @@ LLVMValueRef to_llvm_const(LLVMTypeRef ty, AstNode* lit) { } } +// === coerce_int + type_info_to_llvm (保持) === LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val, - LLVMTypeRef from_ty, LLVMTypeRef to_ty) { + LLVMTypeRef from_ty, LLVMTypeRef to_ty) { if (from_ty == to_ty) return val; int from_w = LLVMGetIntTypeWidth(from_ty); int to_w = LLVMGetIntTypeWidth(to_ty); @@ -45,396 +48,379 @@ LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val, return LLVMBuildTrunc(ctx->builder, val, to_ty, "trunc"); } -// 从 TypeInfo 生成 LLVM 类型(支持数组、结构体等复合类型) LLVMTypeRef type_info_to_llvm(CgCtx* ctx, const TypeInfo* ti) { switch (ti->kind) { - case TYPE_ARRAY: { - TypeInfo elem = { .kind = ti->element_type, .struct_name = ti->element_struct_name }; - LLVMTypeRef elem_ty = type_info_to_llvm(ctx, &elem); - return LLVMArrayType(elem_ty, (unsigned)ti->array_size); + case TYPE_ARRAY: { + LLVMTypeRef elem; + if (ti->element_struct_name) { + LLVMTypeRef st = find_struct_type(ctx, ti->element_struct_name); + elem = st ? st : LLVMInt64TypeInContext(ctx->context); + } else { + elem = to_llvm_type(ctx, ti->element_type); } - case TYPE_STRUCT: - if (ti->struct_name) { - LLVMTypeRef st = find_struct_type(ctx, ti->struct_name); - if (st) return st; - } - return LLVMVoidTypeInContext(ctx->context); - case TYPE_ENUM: { - LLVMTypeRef f[] = { LLVMInt64TypeInContext(ctx->context), - LLVMInt64TypeInContext(ctx->context) }; - return LLVMStructTypeInContext(ctx->context, f, 2, false); + return LLVMArrayType(elem, (unsigned)ti->array_size); + } + case TYPE_STRUCT: + if (ti->struct_name) { + LLVMTypeRef st = find_struct_type(ctx, ti->struct_name); + if (st) return st; } - default: - return to_llvm_type(ctx, ti->kind); + return LLVMVoidTypeInContext(ctx->context); + case TYPE_ENUM: { + LLVMTypeRef f[] = { LLVMInt64TypeInContext(ctx->context), + LLVMInt64TypeInContext(ctx->context) }; + return LLVMStructTypeInContext(ctx->context, f, 2, false); + } + default: + return to_llvm_type(ctx, ti->kind); } } -// === 向前声明 === +// === 前向声明 === LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node); void codegen_stmt(CgCtx* ctx, AstNode* node); -LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) { - if (!node) return NULL; - - switch (node->kind) { - case AST_LITERAL_EXPR: - if (node->type.kind == TYPE_STR) { - return LLVMBuildGlobalStringPtr(ctx->builder, node->as.literal.str_val, "str"); - } - return to_llvm_const(to_llvm_type(ctx, node->type.kind), node); - - case AST_IDENT_EXPR: { - LLVMValueRef ptr = find_var(ctx, node->as.ident.name); - if (!ptr) return NULL; - LLVMTypeRef load_ty = type_info_to_llvm(ctx, &node->type); - return LLVMBuildLoad2(ctx->builder, load_ty, ptr, "load"); +// === Visitor Handler 包装 === +#define CG_HANDLER(name) \ + static void* name(void* vctx, AstNode* node) { \ + return (void*)name##_impl((CgCtx*)vctx, node); \ } - case AST_UNARY_EXPR: { - LLVMValueRef operand = codegen_expr(ctx, node->as.unary.operand); - if (!operand) return NULL; - if (node->as.unary.op == OP_NEG) { - if (node->type.kind == TYPE_F64) - return LLVMBuildFNeg(ctx->builder, operand, "fneg"); - else - return LLVMBuildNeg(ctx->builder, operand, "ineg"); - } else { - return LLVMBuildNot(ctx->builder, operand, "not"); - } - } +#define CG_CTX CgCtx* ctx = (CgCtx*)vctx +#define CG_RET(v) return (void*)(v) - case AST_BINARY_EXPR: { - LLVMValueRef l = codegen_expr(ctx, node->as.binary.left); - LLVMValueRef r = codegen_expr(ctx, node->as.binary.right); - if (!l || !r) return NULL; +// === 各表达式 handler (从原 switch 提取) === - // 字符串拼接:alloc 栈缓冲区,strcpy + strcat - if (node->type.kind == TYPE_STR) { - // strlen(left) - LLVMValueRef len_l = LLVMBuildCall2(ctx->builder, - LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn, - (LLVMValueRef[]){l}, 1, "strlen_l"); - // strlen(right) - LLVMValueRef len_r = LLVMBuildCall2(ctx->builder, - LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn, - (LLVMValueRef[]){r}, 1, "strlen_r"); - // total = len_l + len_r + 1 - LLVMValueRef total = LLVMBuildAdd(ctx->builder, len_l, len_r, "total"); - total = LLVMBuildAdd(ctx->builder, total, - LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "total_1"); - // char* buf = malloc(total) - LLVMValueRef buf = LLVMBuildCall2(ctx->builder, - LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn, - (LLVMValueRef[]){total}, 1, "str_buf"); - // memcpy(buf, left, len_l) - LLVMBuildCall2(ctx->builder, - LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn, - (LLVMValueRef[]){buf, l, len_l}, 3, ""); - // memcpy(buf + len_l, right, len_r + 1) -- includes null terminator - LLVMValueRef offset_ptr = LLVMBuildGEP2(ctx->builder, - LLVMInt8TypeInContext(ctx->context), buf, - (LLVMValueRef[]){len_l}, 1, "offset"); - LLVMValueRef len_r1 = LLVMBuildAdd(ctx->builder, len_r, - LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "len_r1"); - LLVMBuildCall2(ctx->builder, - LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn, - (LLVMValueRef[]){offset_ptr, r, len_r1}, 3, ""); - return buf; - } +static LLVMValueRef cg_literal_impl(CgCtx* ctx, AstNode* node) { + if (node->type.kind == TYPE_STR) + return LLVMBuildGlobalStringPtr(ctx->builder, node->as.literal.str_val, "str"); + return to_llvm_const(to_llvm_type(ctx, node->type.kind), node); +} +CG_HANDLER(cg_literal) - bool is_float = (node->type.kind == TYPE_F64); +static LLVMValueRef cg_ident_impl(CgCtx* ctx, AstNode* node) { + LLVMValueRef ptr = find_var(ctx, node->as.ident.name); + if (!ptr) return NULL; + LLVMTypeRef load_ty = type_info_to_llvm(ctx, &node->type); + return LLVMBuildLoad2(ctx->builder, load_ty, ptr, "load"); +} +CG_HANDLER(cg_ident) - switch (node->as.binary.op) { - case OP_ADD: - return is_float ? LLVMBuildFAdd(ctx->builder, l, r, "fadd") - : LLVMBuildAdd(ctx->builder, l, r, "iadd"); - case OP_SUB: - return is_float ? LLVMBuildFSub(ctx->builder, l, r, "fsub") - : LLVMBuildSub(ctx->builder, l, r, "isub"); - case OP_MUL: - return is_float ? LLVMBuildFMul(ctx->builder, l, r, "fmul") - : LLVMBuildMul(ctx->builder, l, r, "imul"); - case OP_DIV: - return is_float ? LLVMBuildFDiv(ctx->builder, l, r, "fdiv") - : LLVMBuildSDiv(ctx->builder, l, r, "sdiv"); - case OP_MOD: - return LLVMBuildSRem(ctx->builder, l, r, "srem"); - case OP_EQ: - case OP_NE: - case OP_LT: - case OP_GT: - case OP_LE: - case OP_GE: { - // 枚举比较: 提取 tag 字段再比较 - LLVMValueRef cl = l, cr = r; - if (node->as.binary.left->type.kind == TYPE_ENUM) { - cl = LLVMBuildExtractValue(ctx->builder, l, 0, "enum_tag_l"); - } - if (node->as.binary.right->type.kind == TYPE_ENUM) { - cr = LLVMBuildExtractValue(ctx->builder, r, 0, "enum_tag_r"); - } - LLVMIntPredicate pred; - switch (node->as.binary.op) { - case OP_EQ: pred = LLVMIntEQ; break; - case OP_NE: pred = LLVMIntNE; break; - case OP_LT: pred = LLVMIntSLT; break; - case OP_GT: pred = LLVMIntSGT; break; - case OP_LE: pred = LLVMIntSLE; break; - case OP_GE: pred = LLVMIntSGE; break; - default: return NULL; - } - if (is_float) - return LLVMBuildFCmp(ctx->builder, pred == LLVMIntEQ ? LLVMRealOEQ : - pred == LLVMIntNE ? LLVMRealONE : pred == LLVMIntSLT ? LLVMRealOLT : - pred == LLVMIntSGT ? LLVMRealOGT : pred == LLVMIntSLE ? LLVMRealOLE : - LLVMRealOGE, cl, cr, "fcmp"); - return LLVMBuildICmp(ctx->builder, pred, cl, cr, "icmp"); - } - case OP_AND: - return LLVMBuildAnd(ctx->builder, l, r, "and"); - case OP_OR: - return LLVMBuildOr(ctx->builder, l, r, "or"); - default: - return NULL; - } - } - - case AST_CALL_EXPR: { - // === 内置 print 函数:委托给 printf === - if (strcmp(node->as.call.name, "print_i64") == 0) { - LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); - if (!arg) return NULL; - // 枚举类型: 提取 tag 字段 - if (node->as.call.args[0]->type.kind == TYPE_ENUM) - arg = LLVMBuildExtractValue(ctx->builder, arg, 0, "tag"); - LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context); - arg = coerce_int(ctx, arg, LLVMTypeOf(arg), i64_ty); - LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%lld\n", "fmt_i64"); - LLVMValueRef printf_args[] = { fmt, arg }; - return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, - printf_args, 2, ""); - } - if (strcmp(node->as.call.name, "print_f64") == 0) { - LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); - if (!arg) return NULL; - LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%f\n", "fmt_f64"); - LLVMValueRef printf_args[] = { fmt, arg }; - return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, - printf_args, 2, ""); - } - if (strcmp(node->as.call.name, "print_bool") == 0) { - LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); - if (!arg) return NULL; - // 将 bool 转为字符串:通过 select 在 "true\n" 和 "false\n" 之间选择 - LLVMValueRef c = LLVMBuildICmp(ctx->builder, LLVMIntEQ, arg, - LLVMConstInt(LLVMInt1TypeInContext(ctx->context), 1, false), "bool_cmp"); - LLVMValueRef true_str = LLVMBuildGlobalStringPtr(ctx->builder, "true\n", "true_str"); - LLVMValueRef false_str = LLVMBuildGlobalStringPtr(ctx->builder, "false\n", "false_str"); - LLVMValueRef selected = LLVMBuildSelect(ctx->builder, c, true_str, false_str, "bool_sel"); - return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, - (LLVMValueRef[]){selected}, 1, ""); - } - if (strcmp(node->as.call.name, "print_str") == 0) { - LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); - if (!arg) return NULL; - LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%s\n", "fmt_str"); - LLVMValueRef printf_args[] = { fmt, arg }; - return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, - printf_args, 2, ""); - } - - // === 常规函数调用 === - LLVMValueRef fn = find_fn(ctx, node->as.call.name); - if (!fn) return NULL; - LLVMValueRef args[16]; - if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; } - for (size_t i = 0; i < node->as.call.arg_count; i++) { - args[i] = codegen_expr(ctx, node->as.call.args[i]); - if (!args[i]) return NULL; - } - LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn); - LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty); - return LLVMBuildCall2(ctx->builder, fn_ty, fn, - args, (unsigned)node->as.call.arg_count, - ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call"); - } - - // === 结构体字段访问: p.x === - case AST_FIELD_ACCESS: { - // 对对象求值(返回的是 struct 值) - LLVMValueRef struct_val = codegen_expr(ctx, node->as.field_access.object); - if (!struct_val) return NULL; - - int field_idx = node->as.field_access.field_index; - if (field_idx < 0) return NULL; // sema 应当已经设置 - - // 用 extractvalue 从结构体值中提取字段 - return LLVMBuildExtractValue(ctx->builder, struct_val, - (unsigned)field_idx, node->as.field_access.field); - } - - // === 结构体初始化: Point { x: 10, y: 20 } === - case AST_STRUCT_INIT: { - const char* st_name = node->as.struct_init.type_name; - LLVMTypeRef struct_ty = find_struct_type(ctx, st_name); - if (!struct_ty) return NULL; - - // alloca 分配结构体空间 - LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, struct_ty, "struct_init"); - - // 获取结构体字段名列表(从 struct_table 或从 AST 中) - // 对每个 init 字段,找到它在结构体中的索引并 store - for (size_t i = 0; i < node->as.struct_init.field_count; i++) { - AstNode* fval = node->as.struct_init.field_values[i]; - LLVMValueRef val = codegen_expr(ctx, fval); - if (!val) return NULL; - - // 获取字段指针: GEP struct_ty, alloca, 0, i - LLVMValueRef indices[] = { - LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false), - LLVMConstInt(LLVMInt32TypeInContext(ctx->context), (unsigned long long)i, false) - }; - LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, struct_ty, alloca, - indices, 2, "field_ptr"); - LLVMBuildStore(ctx->builder, val, field_ptr); - } - - // 加载整个结构体值 - return LLVMBuildLoad2(ctx->builder, struct_ty, alloca, "struct_val"); - } - - case AST_ENUM_VARIANT: { - // tagged union: { tag, payload } - LLVMValueRef tag = LLVMConstInt(LLVMInt64TypeInContext(ctx->context), - (unsigned long long)node->as.enum_variant.variant_index, true); - LLVMValueRef payload = LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 0, true); - if (node->as.enum_variant.payload) { - LLVMValueRef pv = codegen_expr(ctx, node->as.enum_variant.payload); - if (pv) { - // 将 payload 强制转换为 i64 - LLVMTypeRef pv_ty = LLVMTypeOf(pv); - LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context); - if (pv_ty != i64_ty && LLVMGetTypeKind(pv_ty) == LLVMIntegerTypeKind) - pv = coerce_int(ctx, pv, pv_ty, i64_ty); - payload = pv; - } - } - LLVMValueRef fields[] = { tag, payload }; - return LLVMConstStruct(fields, 2, false); - } - - case AST_METHOD_CALL: { - const char* struct_name = node->as.method_call.receiver->type.struct_name; - char mangled[256]; - // 若 method_name 已含 $(trait 方法,sema 已设置全限定名),直接用 - if (strchr(node->as.method_call.method_name, '$')) - snprintf(mangled, sizeof(mangled), "%s", node->as.method_call.method_name); +static LLVMValueRef cg_unary_impl(CgCtx* ctx, AstNode* node) { + LLVMValueRef operand = codegen_expr(ctx, node->as.unary.operand); + if (!operand) return NULL; + if (node->as.unary.op == OP_NEG) { + if (node->type.kind == TYPE_F64) + return LLVMBuildFNeg(ctx->builder, operand, "fneg"); else - snprintf(mangled, sizeof(mangled), "%s$%s", struct_name, - node->as.method_call.method_name); - LLVMValueRef fn = find_fn(ctx, mangled); - if (!fn) return NULL; - // 参数列表: [receiver, 用户参数...] - if (node->as.method_call.arg_count + 1 > 16) { ctx->error = "方法参数过多(最多15)"; return NULL; } - LLVMValueRef args[16]; - args[0] = codegen_expr(ctx, node->as.method_call.receiver); - if (!args[0]) return NULL; - for (size_t i = 0; i < node->as.method_call.arg_count; i++) { - args[i + 1] = codegen_expr(ctx, node->as.method_call.args[i]); - if (!args[i + 1]) return NULL; - } - LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn); - LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty); - return LLVMBuildCall2(ctx->builder, fn_ty, fn, args, - (unsigned)(node->as.method_call.arg_count + 1), - ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "method_call"); + return LLVMBuildNeg(ctx->builder, operand, "ineg"); + } + return LLVMBuildNot(ctx->builder, operand, "not"); +} +CG_HANDLER(cg_unary) + +static LLVMValueRef cg_binary_impl(CgCtx* ctx, AstNode* node) { + LLVMValueRef l = codegen_expr(ctx, node->as.binary.left); + LLVMValueRef r = codegen_expr(ctx, node->as.binary.right); + if (!l || !r) return NULL; + + if (node->type.kind == TYPE_STR) { + LLVMValueRef len_l = LLVMBuildCall2(ctx->builder, + LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn, + (LLVMValueRef[]){l}, 1, "strlen_l"); + LLVMValueRef len_r = LLVMBuildCall2(ctx->builder, + LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn, + (LLVMValueRef[]){r}, 1, "strlen_r"); + LLVMValueRef total = LLVMBuildAdd(ctx->builder, len_l, len_r, "total"); + total = LLVMBuildAdd(ctx->builder, total, + LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "total_1"); + LLVMValueRef buf = LLVMBuildCall2(ctx->builder, + LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn, + (LLVMValueRef[]){total}, 1, "str_buf"); + LLVMBuildCall2(ctx->builder, + LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn, + (LLVMValueRef[]){buf, l, len_l}, 3, ""); + LLVMValueRef offset_ptr = LLVMBuildGEP2(ctx->builder, + LLVMInt8TypeInContext(ctx->context), buf, + (LLVMValueRef[]){len_l}, 1, "offset"); + LLVMValueRef len_r1 = LLVMBuildAdd(ctx->builder, len_r, + LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "len_r1"); + LLVMBuildCall2(ctx->builder, + LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn, + (LLVMValueRef[]){offset_ptr, r, len_r1}, 3, ""); + return buf; } - case AST_INDEX_EXPR: { - // 获取数组变量的指针 - AstNode* arr_node = node->as.index_expr.array; - LLVMValueRef arr_ptr = NULL; - LLVMTypeRef arr_gp_type = NULL; - - if (arr_node->kind == AST_IDENT_EXPR) { - arr_ptr = find_var(ctx, arr_node->as.ident.name); - // 从变量表获取数组类型用于 GEP - for (VarEntry* e = ctx->var_table; e; e = e->next) { - if (strcmp(e->name, arr_node->as.ident.name) == 0) { - arr_gp_type = e->alloca_type; break; - } + bool is_float = (node->type.kind == TYPE_F64); + switch (node->as.binary.op) { + case OP_ADD: + return is_float ? LLVMBuildFAdd(ctx->builder, l, r, "fadd") + : LLVMBuildAdd(ctx->builder, l, r, "iadd"); + case OP_SUB: + return is_float ? LLVMBuildFSub(ctx->builder, l, r, "fsub") + : LLVMBuildSub(ctx->builder, l, r, "isub"); + case OP_MUL: + return is_float ? LLVMBuildFMul(ctx->builder, l, r, "fmul") + : LLVMBuildMul(ctx->builder, l, r, "imul"); + case OP_DIV: + return is_float ? LLVMBuildFDiv(ctx->builder, l, r, "fdiv") + : LLVMBuildSDiv(ctx->builder, l, r, "sdiv"); + case OP_MOD: + return LLVMBuildSRem(ctx->builder, l, r, "srem"); + case OP_EQ: case OP_NE: case OP_LT: case OP_GT: case OP_LE: case OP_GE: { + LLVMValueRef cl = l, cr = r; + if (node->as.binary.left->type.kind == TYPE_ENUM) + cl = LLVMBuildExtractValue(ctx->builder, l, 0, "enum_tag_l"); + if (node->as.binary.right->type.kind == TYPE_ENUM) + cr = LLVMBuildExtractValue(ctx->builder, r, 0, "enum_tag_r"); + LLVMIntPredicate pred; + switch (node->as.binary.op) { + case OP_EQ: pred = LLVMIntEQ; break; + case OP_NE: pred = LLVMIntNE; break; + case OP_LT: pred = LLVMIntSLT; break; + case OP_GT: pred = LLVMIntSGT; break; + case OP_LE: pred = LLVMIntSLE; break; + case OP_GE: pred = LLVMIntSGE; break; + default: return NULL; } + if (is_float) + return LLVMBuildFCmp(ctx->builder, pred == LLVMIntEQ ? LLVMRealOEQ : + pred == LLVMIntNE ? LLVMRealONE : pred == LLVMIntSLT ? LLVMRealOLT : + pred == LLVMIntSGT ? LLVMRealOGT : pred == LLVMIntSLE ? LLVMRealOLE : + LLVMRealOGE, cl, cr, "fcmp"); + return LLVMBuildICmp(ctx->builder, pred, cl, cr, "icmp"); } - if (!arr_ptr || !arr_gp_type) return NULL; - - // 生成索引值 - LLVMValueRef idx_val = codegen_expr(ctx, node->as.index_expr.index); - if (!idx_val) return NULL; - - // GEP 索引必须是 i32,但 L 使用 i64。截断。 - LLVMValueRef idx_i32 = LLVMBuildTrunc(ctx->builder, idx_val, - LLVMInt32TypeInContext(ctx->context), "idx32"); - - LLVMValueRef indices[] = { - LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false), - idx_i32 - }; - LLVMValueRef elem_ptr = LLVMBuildGEP2(ctx->builder, arr_gp_type, arr_ptr, indices, 2, "arr_elem"); - - LLVMTypeRef elem_load_ty; - if (node->type.kind == TYPE_STRUCT && node->type.struct_name) { - elem_load_ty = find_struct_type(ctx, node->type.struct_name); - if (!elem_load_ty) elem_load_ty = to_llvm_type(ctx, node->type.kind); - } else { - elem_load_ty = type_info_to_llvm(ctx, &node->type); - } - return LLVMBuildLoad2(ctx->builder, elem_load_ty, elem_ptr, "arr_load"); - } - - // 块表达式: { stmt*; expr } → 最后表达式的值 - case AST_BLOCK: { - LLVMValueRef result = NULL; - for (size_t i = 0; i < node->as.block.stmt_count; i++) { - AstNode* stmt = node->as.block.stmts[i]; - bool is_last = (i == node->as.block.stmt_count - 1); - if (is_last && stmt->kind == AST_EXPR_STMT && node->type.kind != TYPE_VOID) { - result = codegen_expr(ctx, stmt->as.expr_stmt.expr); - } else { - codegen_stmt(ctx, stmt); - } - } - return result; - } - - // if 表达式: if cond { a } else { b } - case AST_IF_STMT: { - if (node->type.kind == TYPE_VOID) { codegen_stmt(ctx, node); return NULL; } - LLVMValueRef cond_val = codegen_expr(ctx, node->as.if_stmt.cond); - if (!cond_val) return NULL; - LLVMTypeRef res_ty = type_info_to_llvm(ctx, &node->type); - LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, res_ty, "if_res"); - LLVMValueRef func = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx->builder)); - LLVMBasicBlockRef then_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "then"); - LLVMBasicBlockRef else_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "else"); - LLVMBasicBlockRef merge_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "if_merge"); - LLVMBuildCondBr(ctx->builder, cond_val, then_bb, else_bb); - - LLVMPositionBuilderAtEnd(ctx->builder, then_bb); - LLVMValueRef then_val = codegen_expr(ctx, node->as.if_stmt.then_block); - if (then_val) LLVMBuildStore(ctx->builder, then_val, alloca); - LLVMBuildBr(ctx->builder, merge_bb); - - LLVMPositionBuilderAtEnd(ctx->builder, else_bb); - LLVMValueRef else_val = codegen_expr(ctx, node->as.if_stmt.else_block); - if (else_val) LLVMBuildStore(ctx->builder, else_val, alloca); - LLVMBuildBr(ctx->builder, merge_bb); - - LLVMPositionBuilderAtEnd(ctx->builder, merge_bb); - return LLVMBuildLoad2(ctx->builder, res_ty, alloca, "if_val"); - } - - default: - return NULL; + case OP_AND: return LLVMBuildAnd(ctx->builder, l, r, "and"); + case OP_OR: return LLVMBuildOr(ctx->builder, l, r, "or"); + default: return NULL; } } +CG_HANDLER(cg_binary) +static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) { + // 内置 print 委托给 printf + if (strcmp(node->as.call.name, "print_i64") == 0) { + LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); + if (!arg) return NULL; + if (node->as.call.args[0]->type.kind == TYPE_ENUM) + arg = LLVMBuildExtractValue(ctx->builder, arg, 0, "tag"); + LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context); + arg = coerce_int(ctx, arg, LLVMTypeOf(arg), i64_ty); + LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%lld\n", "fmt_i64"); + return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, + (LLVMValueRef[]){fmt, arg}, 2, ""); + } + if (strcmp(node->as.call.name, "print_f64") == 0) { + LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); + if (!arg) return NULL; + LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%f\n", "fmt_f64"); + return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, + (LLVMValueRef[]){fmt, arg}, 2, ""); + } + if (strcmp(node->as.call.name, "print_bool") == 0) { + LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); + if (!arg) return NULL; + LLVMValueRef c = LLVMBuildICmp(ctx->builder, LLVMIntEQ, arg, + LLVMConstInt(LLVMInt1TypeInContext(ctx->context), 1, false), "bool_cmp"); + LLVMValueRef true_str = LLVMBuildGlobalStringPtr(ctx->builder, "true\n", "true_str"); + LLVMValueRef false_str = LLVMBuildGlobalStringPtr(ctx->builder, "false\n", "false_str"); + LLVMValueRef selected = LLVMBuildSelect(ctx->builder, c, true_str, false_str, "bool_sel"); + return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, + (LLVMValueRef[]){selected}, 1, ""); + } + if (strcmp(node->as.call.name, "print_str") == 0) { + LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); + if (!arg) return NULL; + LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%s\n", "fmt_str"); + return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, + (LLVMValueRef[]){fmt, arg}, 2, ""); + } + LLVMValueRef fn = find_fn(ctx, node->as.call.name); + if (!fn) return NULL; + LLVMValueRef args[16]; + if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; } + for (size_t i = 0; i < node->as.call.arg_count; i++) { + args[i] = codegen_expr(ctx, node->as.call.args[i]); + if (!args[i]) return NULL; + } + LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn); + LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty); + return LLVMBuildCall2(ctx->builder, fn_ty, fn, args, + (unsigned)node->as.call.arg_count, + ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call"); +} +CG_HANDLER(cg_call) + +static LLVMValueRef cg_field_access_impl(CgCtx* ctx, AstNode* node) { + LLVMValueRef struct_val = codegen_expr(ctx, node->as.field_access.object); + if (!struct_val) return NULL; + int field_idx = node->as.field_access.field_index; + if (field_idx < 0) return NULL; + return LLVMBuildExtractValue(ctx->builder, struct_val, + (unsigned)field_idx, node->as.field_access.field); +} +CG_HANDLER(cg_field_access) + +static LLVMValueRef cg_struct_init_impl(CgCtx* ctx, AstNode* node) { + const char* st_name = node->as.struct_init.type_name; + LLVMTypeRef struct_ty = find_struct_type(ctx, st_name); + if (!struct_ty) return NULL; + LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, struct_ty, "struct_init"); + for (size_t i = 0; i < node->as.struct_init.field_count; i++) { + AstNode* fval = node->as.struct_init.field_values[i]; + LLVMValueRef val = codegen_expr(ctx, fval); + if (!val) return NULL; + LLVMValueRef indices[] = { + LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false), + LLVMConstInt(LLVMInt32TypeInContext(ctx->context), (unsigned long long)i, false) + }; + LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, struct_ty, alloca, + indices, 2, "field_ptr"); + LLVMBuildStore(ctx->builder, val, field_ptr); + } + return LLVMBuildLoad2(ctx->builder, struct_ty, alloca, "struct_val"); +} +CG_HANDLER(cg_struct_init) + +static LLVMValueRef cg_enum_variant_impl(CgCtx* ctx, AstNode* node) { + LLVMValueRef tag = LLVMConstInt(LLVMInt64TypeInContext(ctx->context), + (unsigned long long)node->as.enum_variant.variant_index, true); + LLVMValueRef payload = LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 0, true); + if (node->as.enum_variant.payload) { + LLVMValueRef pv = codegen_expr(ctx, node->as.enum_variant.payload); + if (pv) { + LLVMTypeRef pv_ty = LLVMTypeOf(pv); + LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context); + if (pv_ty != i64_ty && LLVMGetTypeKind(pv_ty) == LLVMIntegerTypeKind) + pv = coerce_int(ctx, pv, pv_ty, i64_ty); + payload = pv; + } + } + LLVMValueRef fields[] = { tag, payload }; + return LLVMConstStruct(fields, 2, false); +} +CG_HANDLER(cg_enum_variant) + +static LLVMValueRef cg_method_call_impl(CgCtx* ctx, AstNode* node) { + const char* struct_name = node->as.method_call.receiver->type.struct_name; + char mangled[256]; + if (strchr(node->as.method_call.method_name, '$')) + snprintf(mangled, sizeof(mangled), "%s", node->as.method_call.method_name); + else + snprintf(mangled, sizeof(mangled), "%s$%s", struct_name, + node->as.method_call.method_name); + LLVMValueRef fn = find_fn(ctx, mangled); + if (!fn) return NULL; + if (node->as.method_call.arg_count + 1 > 16) { ctx->error = "方法参数过多(最多15)"; return NULL; } + LLVMValueRef args[16]; + args[0] = codegen_expr(ctx, node->as.method_call.receiver); + if (!args[0]) return NULL; + for (size_t i = 0; i < node->as.method_call.arg_count; i++) { + args[i + 1] = codegen_expr(ctx, node->as.method_call.args[i]); + if (!args[i + 1]) return NULL; + } + LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn); + LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty); + return LLVMBuildCall2(ctx->builder, fn_ty, fn, args, + (unsigned)(node->as.method_call.arg_count + 1), + ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "method_call"); +} +CG_HANDLER(cg_method_call) + +static LLVMValueRef cg_index_impl(CgCtx* ctx, AstNode* node) { + AstNode* arr_node = node->as.index_expr.array; + LLVMValueRef arr_ptr = NULL; + LLVMTypeRef arr_gp_type = NULL; + if (arr_node->kind == AST_IDENT_EXPR) { + arr_ptr = find_var(ctx, arr_node->as.ident.name); + for (VarEntry* e = ctx->var_table; e; e = e->next) { + if (strcmp(e->name, arr_node->as.ident.name) == 0) { + arr_gp_type = e->alloca_type; break; + } + } + } + if (!arr_ptr || !arr_gp_type) return NULL; + LLVMValueRef idx_val = codegen_expr(ctx, node->as.index_expr.index); + if (!idx_val) return NULL; + LLVMValueRef idx_i32 = LLVMBuildTrunc(ctx->builder, idx_val, + LLVMInt32TypeInContext(ctx->context), "idx32"); + LLVMValueRef indices[] = { + LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false), idx_i32 + }; + LLVMValueRef elem_ptr = LLVMBuildGEP2(ctx->builder, arr_gp_type, arr_ptr, indices, 2, "arr_elem"); + LLVMTypeRef elem_load_ty; + if (node->type.kind == TYPE_STRUCT && node->type.struct_name) { + elem_load_ty = find_struct_type(ctx, node->type.struct_name); + if (!elem_load_ty) elem_load_ty = to_llvm_type(ctx, node->type.kind); + } else { + elem_load_ty = type_info_to_llvm(ctx, &node->type); + } + return LLVMBuildLoad2(ctx->builder, elem_load_ty, elem_ptr, "arr_load"); +} +CG_HANDLER(cg_index) + +static LLVMValueRef cg_block_impl(CgCtx* ctx, AstNode* node) { + LLVMValueRef result = NULL; + for (size_t i = 0; i < node->as.block.stmt_count; i++) { + AstNode* stmt = node->as.block.stmts[i]; + bool is_last = (i == node->as.block.stmt_count - 1); + if (is_last && stmt->kind == AST_EXPR_STMT && node->type.kind != TYPE_VOID) + result = codegen_expr(ctx, stmt->as.expr_stmt.expr); + else + codegen_stmt(ctx, stmt); + } + return result; +} +CG_HANDLER(cg_block) + +static LLVMValueRef cg_if_expr_impl(CgCtx* ctx, AstNode* node) { + if (node->type.kind == TYPE_VOID) { codegen_stmt(ctx, node); return NULL; } + LLVMValueRef cond_val = codegen_expr(ctx, node->as.if_stmt.cond); + if (!cond_val) return NULL; + LLVMTypeRef res_ty = type_info_to_llvm(ctx, &node->type); + LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, res_ty, "if_res"); + LLVMValueRef func = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx->builder)); + LLVMBasicBlockRef then_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "then"); + LLVMBasicBlockRef else_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "else"); + LLVMBasicBlockRef merge_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "if_merge"); + LLVMBuildCondBr(ctx->builder, cond_val, then_bb, else_bb); + + LLVMPositionBuilderAtEnd(ctx->builder, then_bb); + LLVMValueRef then_val = codegen_expr(ctx, node->as.if_stmt.then_block); + if (then_val) LLVMBuildStore(ctx->builder, then_val, alloca); + LLVMBuildBr(ctx->builder, merge_bb); + + LLVMPositionBuilderAtEnd(ctx->builder, else_bb); + LLVMValueRef else_val = codegen_expr(ctx, node->as.if_stmt.else_block); + if (else_val) LLVMBuildStore(ctx->builder, else_val, alloca); + LLVMBuildBr(ctx->builder, merge_bb); + + LLVMPositionBuilderAtEnd(ctx->builder, merge_bb); + return LLVMBuildLoad2(ctx->builder, res_ty, alloca, "if_val"); +} +CG_HANDLER(cg_if_expr) + +// === Visitor Dispatch 表 === +static AstDispatch cg_dispatch; + +void codegen_expr_init(void) { + ast_dispatch_set(&cg_dispatch, AST_LITERAL_EXPR, cg_literal); + ast_dispatch_set(&cg_dispatch, AST_IDENT_EXPR, cg_ident); + ast_dispatch_set(&cg_dispatch, AST_UNARY_EXPR, cg_unary); + ast_dispatch_set(&cg_dispatch, AST_BINARY_EXPR, cg_binary); + ast_dispatch_set(&cg_dispatch, AST_CALL_EXPR, cg_call); + ast_dispatch_set(&cg_dispatch, AST_FIELD_ACCESS, cg_field_access); + ast_dispatch_set(&cg_dispatch, AST_STRUCT_INIT, cg_struct_init); + ast_dispatch_set(&cg_dispatch, AST_ENUM_VARIANT, cg_enum_variant); + ast_dispatch_set(&cg_dispatch, AST_METHOD_CALL, cg_method_call); + ast_dispatch_set(&cg_dispatch, AST_INDEX_EXPR, cg_index); + ast_dispatch_set(&cg_dispatch, AST_BLOCK, cg_block); + ast_dispatch_set(&cg_dispatch, AST_IF_STMT, cg_if_expr); +} + +// === 统一入口 === +LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) { + if (!node) return NULL; + cg_dispatch.ctx = ctx; + return (LLVMValueRef)ast_visit(&cg_dispatch, node); +} diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c index ea5a849..417fc6b 100644 --- a/src/codegen/codegen.c +++ b/src/codegen/codegen.c @@ -288,6 +288,7 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, } ctx.module = LLVMModuleCreateWithNameInContext(name, ctx.context); ctx.builder = LLVMCreateBuilderInContext(ctx.context); + codegen_expr_init(); // 声明 C 标准库 printf(内置 print 函数依赖它) LLVMTypeRef printf_param_types[] = { diff --git a/src/codegen/codegen_internal.h b/src/codegen/codegen_internal.h index f87aebd..757ba07 100644 --- a/src/codegen/codegen_internal.h +++ b/src/codegen/codegen_internal.h @@ -5,6 +5,7 @@ #include "ast.h" #include "arena.h" #include "l_lang.h" +#include "visit.h" #include #include #include @@ -77,6 +78,7 @@ void cleanup_add(CgCtx* ctx, LLVMValueRef alloca); void cleanup_emit(CgCtx* ctx, size_t from_mark); // === 代码生成函数 === +void codegen_expr_init(void); LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node); void codegen_stmt(CgCtx* ctx, AstNode* node);