From 6d1db585c4aeb093dd7b265b70b53de048f81653 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sat, 6 Jun 2026 19:26:54 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20sema.c=20+=20codegen.c=20=E6=8B=86?= =?UTF-8?q?=E5=88=86=EF=BC=8C=E5=85=A8=E9=83=A8=E6=BA=90=E6=96=87=E4=BB=B6?= =?UTF-8?q?=20<800=20=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sema.c 1129行 → sema.c 499行 + typeck.c 629行 + sema_internal.h 51行 - typeck.c: 表达式类型检查 (10个analyze_*函数) + 泛型单态化 + 类型关系 - sema.c: analyze_node + sema_analyze codegen.c 947行 → codegen.c 453行 + cg_expr.c 440行 + codegen_internal.h 83行 - cg_expr.c: LLVM表达式生成 + 类型映射 (to_llvm_type/coerce_int/type_info_to_llvm) - codegen.c: 语句生成 + 模块入口 + 符号表 + 内存清理 全部核心源文件 <800 行限制: parser(662+498), sema(499+629), codegen(453+440) Co-Authored-By: Claude Opus 4.7 --- src/codegen/cg_expr.c | 440 +++++++++++++++++++++++ src/codegen/codegen.c | 516 +-------------------------- src/codegen/codegen_internal.h | 83 +++++ src/sema/sema.c | 634 +-------------------------------- src/sema/sema_internal.h | 51 +++ src/sema/typeck.c | 629 ++++++++++++++++++++++++++++++++ 6 files changed, 1216 insertions(+), 1137 deletions(-) create mode 100644 src/codegen/cg_expr.c create mode 100644 src/codegen/codegen_internal.h create mode 100644 src/sema/sema_internal.h create mode 100644 src/sema/typeck.c diff --git a/src/codegen/cg_expr.c b/src/codegen/cg_expr.c new file mode 100644 index 0000000..9ea47fe --- /dev/null +++ b/src/codegen/cg_expr.c @@ -0,0 +1,440 @@ +#include "codegen_internal.h" + +LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) { + switch (kind) { + case TYPE_I32: return LLVMInt32TypeInContext(ctx->context); + case TYPE_I64: return LLVMInt64TypeInContext(ctx->context); + case TYPE_U64: return LLVMInt64TypeInContext(ctx->context); + case TYPE_F64: return LLVMDoubleTypeInContext(ctx->context); + case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context); + 
case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context); + case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0); + case TYPE_STRUCT: + case TYPE_ENUM: { + // tagged union: { i64 tag, i64 payload } + LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context), + LLVMInt64TypeInContext(ctx->context) }; + return LLVMStructTypeInContext(ctx->context, fields, 2, false); + } + case TYPE_UNKNOWN: + case TYPE_ERROR: + default: return LLVMVoidTypeInContext(ctx->context); + } +} + +LLVMValueRef to_llvm_const(LLVMTypeRef ty, AstNode* lit) { + switch (lit->as.literal.lit_type) { + case TYPE_I32: + case TYPE_I64: return LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, true); + case TYPE_U64: return LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, false); + case TYPE_CHAR: return LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, false); + case TYPE_F64: return LLVMConstReal(ty, lit->as.literal.f64_val); + case TYPE_BOOL: return LLVMConstInt(ty, lit->as.literal.bool_val ? 
1 : 0, false); + default: return NULL; + } +} + +LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val, + LLVMTypeRef from_ty, LLVMTypeRef to_ty) { + if (from_ty == to_ty) return val; + int from_w = LLVMGetIntTypeWidth(from_ty); + int to_w = LLVMGetIntTypeWidth(to_ty); + if (from_w < to_w) + return LLVMBuildSExt(ctx->builder, val, to_ty, "sext"); + else + return LLVMBuildTrunc(ctx->builder, val, to_ty, "trunc"); +} + +// 从 TypeInfo 生成 LLVM 类型(支持数组、结构体等复合类型) +LLVMTypeRef type_info_to_llvm(CgCtx* ctx, const TypeInfo* ti) { + switch (ti->kind) { + case TYPE_ARRAY: { + TypeInfo elem = { .kind = ti->element_type, .struct_name = ti->element_struct_name }; + LLVMTypeRef elem_ty = type_info_to_llvm(ctx, &elem); + return LLVMArrayType(elem_ty, (unsigned)ti->array_size); + } + case TYPE_STRUCT: + if (ti->struct_name) { + LLVMTypeRef st = find_struct_type(ctx, ti->struct_name); + if (st) return st; + } + return LLVMVoidTypeInContext(ctx->context); + case TYPE_ENUM: { + LLVMTypeRef f[] = { LLVMInt64TypeInContext(ctx->context), + LLVMInt64TypeInContext(ctx->context) }; + return LLVMStructTypeInContext(ctx->context, f, 2, false); + } + default: + return to_llvm_type(ctx, ti->kind); + } +} + +// === 向前声明 === +LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node); +void codegen_stmt(CgCtx* ctx, AstNode* node); + +LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) { + if (!node) return NULL; + + switch (node->kind) { + case AST_LITERAL_EXPR: + if (node->type.kind == TYPE_STR) { + return LLVMBuildGlobalStringPtr(ctx->builder, node->as.literal.str_val, "str"); + } + return to_llvm_const(to_llvm_type(ctx, node->type.kind), node); + + case AST_IDENT_EXPR: { + LLVMValueRef ptr = find_var(ctx, node->as.ident.name); + if (!ptr) return NULL; + LLVMTypeRef load_ty = type_info_to_llvm(ctx, &node->type); + return LLVMBuildLoad2(ctx->builder, load_ty, ptr, "load"); + } + + case AST_UNARY_EXPR: { + LLVMValueRef operand = codegen_expr(ctx, node->as.unary.operand); + if (!operand) return NULL; + 
if (node->as.unary.op == OP_NEG) {
+            if (node->type.kind == TYPE_F64)
+                return LLVMBuildFNeg(ctx->builder, operand, "fneg");
+            else
+                return LLVMBuildNeg(ctx->builder, operand, "ineg");
+        } else { // every unary operator other than OP_NEG is emitted as a bitwise not
+            return LLVMBuildNot(ctx->builder, operand, "not");
+        }
+    }
+
+    case AST_BINARY_EXPR: {
+        LLVMValueRef l = codegen_expr(ctx, node->as.binary.left);
+        LLVMValueRef r = codegen_expr(ctx, node->as.binary.right);
+        if (!l || !r) return NULL;
+
+        // String concatenation: malloc strlen(l)+strlen(r)+1 bytes and memcpy both halves into the new buffer.
+        if (node->type.kind == TYPE_STR) {
+            // strlen(left)
+            LLVMValueRef len_l = LLVMBuildCall2(ctx->builder,
+                LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn,
+                (LLVMValueRef[]){l}, 1, "strlen_l");
+            // strlen(right)
+            LLVMValueRef len_r = LLVMBuildCall2(ctx->builder,
+                LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn,
+                (LLVMValueRef[]){r}, 1, "strlen_r");
+            // total = len_l + len_r + 1 (room for the terminating NUL)
+            LLVMValueRef total = LLVMBuildAdd(ctx->builder, len_l, len_r, "total");
+            total = LLVMBuildAdd(ctx->builder, total,
+                LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "total_1");
+            // char* buf = malloc(total); NOTE(review): the emitted code does not NULL-check the result
+            LLVMValueRef buf = LLVMBuildCall2(ctx->builder,
+                LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn,
+                (LLVMValueRef[]){total}, 1, "str_buf");
+            // memcpy(buf, left, len_l)
+            LLVMBuildCall2(ctx->builder,
+                LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn,
+                (LLVMValueRef[]){buf, l, len_l}, 3, "");
+            // memcpy(buf + len_l, right, len_r + 1) -- the extra byte copies right's NUL terminator
+            LLVMValueRef offset_ptr = LLVMBuildGEP2(ctx->builder,
+                LLVMInt8TypeInContext(ctx->context), buf,
+                (LLVMValueRef[]){len_l}, 1, "offset");
+            LLVMValueRef len_r1 = LLVMBuildAdd(ctx->builder, len_r,
+                LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "len_r1");
+            LLVMBuildCall2(ctx->builder,
+                LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn,
+                (LLVMValueRef[]){offset_ptr, r, len_r1}, 3, "");
+            return buf;
+        }
+        // Also look at the left operand: a comparison's result type is not F64 even when it compares doubles, and icmp on doubles is invalid IR.
+        bool is_float = (node->type.kind == TYPE_F64) || (node->as.binary.left->type.kind == TYPE_F64);
+
+        switch 
(node->as.binary.op) { + case OP_ADD: + return is_float ? LLVMBuildFAdd(ctx->builder, l, r, "fadd") + : LLVMBuildAdd(ctx->builder, l, r, "iadd"); + case OP_SUB: + return is_float ? LLVMBuildFSub(ctx->builder, l, r, "fsub") + : LLVMBuildSub(ctx->builder, l, r, "isub"); + case OP_MUL: + return is_float ? LLVMBuildFMul(ctx->builder, l, r, "fmul") + : LLVMBuildMul(ctx->builder, l, r, "imul"); + case OP_DIV: + return is_float ? LLVMBuildFDiv(ctx->builder, l, r, "fdiv") + : LLVMBuildSDiv(ctx->builder, l, r, "sdiv"); + case OP_MOD: + return LLVMBuildSRem(ctx->builder, l, r, "srem"); + case OP_EQ: + case OP_NE: + case OP_LT: + case OP_GT: + case OP_LE: + case OP_GE: { + // 枚举比较: 提取 tag 字段再比较 + LLVMValueRef cl = l, cr = r; + if (node->as.binary.left->type.kind == TYPE_ENUM) { + cl = LLVMBuildExtractValue(ctx->builder, l, 0, "enum_tag_l"); + } + if (node->as.binary.right->type.kind == TYPE_ENUM) { + cr = LLVMBuildExtractValue(ctx->builder, r, 0, "enum_tag_r"); + } + LLVMIntPredicate pred; + switch (node->as.binary.op) { + case OP_EQ: pred = LLVMIntEQ; break; + case OP_NE: pred = LLVMIntNE; break; + case OP_LT: pred = LLVMIntSLT; break; + case OP_GT: pred = LLVMIntSGT; break; + case OP_LE: pred = LLVMIntSLE; break; + case OP_GE: pred = LLVMIntSGE; break; + default: return NULL; + } + if (is_float) + return LLVMBuildFCmp(ctx->builder, pred == LLVMIntEQ ? LLVMRealOEQ : + pred == LLVMIntNE ? LLVMRealONE : pred == LLVMIntSLT ? LLVMRealOLT : + pred == LLVMIntSGT ? LLVMRealOGT : pred == LLVMIntSLE ? 
LLVMRealOLE : + LLVMRealOGE, cl, cr, "fcmp"); + return LLVMBuildICmp(ctx->builder, pred, cl, cr, "icmp"); + } + case OP_AND: + return LLVMBuildAnd(ctx->builder, l, r, "and"); + case OP_OR: + return LLVMBuildOr(ctx->builder, l, r, "or"); + default: + return NULL; + } + } + + case AST_CALL_EXPR: { + // === 内置 print 函数:委托给 printf === + if (strcmp(node->as.call.name, "print_i64") == 0) { + LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); + if (!arg) return NULL; + // 枚举类型: 提取 tag 字段 + if (node->as.call.args[0]->type.kind == TYPE_ENUM) + arg = LLVMBuildExtractValue(ctx->builder, arg, 0, "tag"); + LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context); + arg = coerce_int(ctx, arg, LLVMTypeOf(arg), i64_ty); + LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%lld\n", "fmt_i64"); + LLVMValueRef printf_args[] = { fmt, arg }; + return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, + printf_args, 2, ""); + } + if (strcmp(node->as.call.name, "print_f64") == 0) { + LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); + if (!arg) return NULL; + LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%f\n", "fmt_f64"); + LLVMValueRef printf_args[] = { fmt, arg }; + return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, + printf_args, 2, ""); + } + if (strcmp(node->as.call.name, "print_bool") == 0) { + LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); + if (!arg) return NULL; + // 将 bool 转为字符串:通过 select 在 "true\n" 和 "false\n" 之间选择 + LLVMValueRef c = LLVMBuildICmp(ctx->builder, LLVMIntEQ, arg, + LLVMConstInt(LLVMInt1TypeInContext(ctx->context), 1, false), "bool_cmp"); + LLVMValueRef true_str = LLVMBuildGlobalStringPtr(ctx->builder, "true\n", "true_str"); + LLVMValueRef false_str = LLVMBuildGlobalStringPtr(ctx->builder, "false\n", "false_str"); + LLVMValueRef selected = LLVMBuildSelect(ctx->builder, c, true_str, false_str, "bool_sel"); + return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, + 
(LLVMValueRef[]){selected}, 1, ""); + } + if (strcmp(node->as.call.name, "print_str") == 0) { + LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); + if (!arg) return NULL; + LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%s\n", "fmt_str"); + LLVMValueRef printf_args[] = { fmt, arg }; + return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, + printf_args, 2, ""); + } + + // === 常规函数调用 === + LLVMValueRef fn = find_fn(ctx, node->as.call.name); + if (!fn) return NULL; + LLVMValueRef args[16]; + if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; } + for (size_t i = 0; i < node->as.call.arg_count; i++) { + args[i] = codegen_expr(ctx, node->as.call.args[i]); + if (!args[i]) return NULL; + } + LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn); + LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty); + return LLVMBuildCall2(ctx->builder, fn_ty, fn, + args, (unsigned)node->as.call.arg_count, + ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call"); + } + + // === 结构体字段访问: p.x === + case AST_FIELD_ACCESS: { + // 对对象求值(返回的是 struct 值) + LLVMValueRef struct_val = codegen_expr(ctx, node->as.field_access.object); + if (!struct_val) return NULL; + + int field_idx = node->as.field_access.field_index; + if (field_idx < 0) return NULL; // sema 应当已经设置 + + // 用 extractvalue 从结构体值中提取字段 + return LLVMBuildExtractValue(ctx->builder, struct_val, + (unsigned)field_idx, node->as.field_access.field); + } + + // === 结构体初始化: Point { x: 10, y: 20 } === + case AST_STRUCT_INIT: { + const char* st_name = node->as.struct_init.type_name; + LLVMTypeRef struct_ty = find_struct_type(ctx, st_name); + if (!struct_ty) return NULL; + + // alloca 分配结构体空间 + LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, struct_ty, "struct_init"); + + // 获取结构体字段名列表(从 struct_table 或从 AST 中) + // 对每个 init 字段,找到它在结构体中的索引并 store + for (size_t i = 0; i < node->as.struct_init.field_count; i++) { + AstNode* fval = node->as.struct_init.field_values[i]; + LLVMValueRef val = 
codegen_expr(ctx, fval);
+            if (!val) return NULL;
+
+            // Field pointer: GEP struct_ty, alloca, 0, i. NOTE(review): initializer i is stored into field i, so declaration order is assumed -- confirm sema enforces or reorders it.
+            LLVMValueRef indices[] = {
+                LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
+                LLVMConstInt(LLVMInt32TypeInContext(ctx->context), (unsigned long long)i, false)
+            };
+            LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, struct_ty, alloca,
+                indices, 2, "field_ptr");
+            LLVMBuildStore(ctx->builder, val, field_ptr);
+        }
+
+        // Load the initialized struct back as a first-class value.
+        return LLVMBuildLoad2(ctx->builder, struct_ty, alloca, "struct_val");
+    }
+
+    case AST_ENUM_VARIANT: {
+        // Tagged union { i64 tag, i64 payload }, assembled with insertvalue: the payload may be a runtime value, and LLVMConstStruct only accepts constants and uses the global context.
+        LLVMValueRef tag = LLVMConstInt(LLVMInt64TypeInContext(ctx->context),
+            (unsigned long long)node->as.enum_variant.variant_index, true);
+        LLVMValueRef payload = LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 0, true);
+        if (node->as.enum_variant.payload) {
+            LLVMValueRef pv = codegen_expr(ctx, node->as.enum_variant.payload);
+            if (pv) {
+                // Integer payloads are converted to i64. NOTE(review): non-integer payloads are passed through unconverted -- confirm sema restricts payloads to integers.
+                LLVMTypeRef pv_ty = LLVMTypeOf(pv);
+                LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context);
+                if (pv_ty != i64_ty && LLVMGetTypeKind(pv_ty) == LLVMIntegerTypeKind)
+                    pv = coerce_int(ctx, pv, pv_ty, i64_ty);
+                payload = pv;
+            }
+        }
+        LLVMValueRef agg = LLVMBuildInsertValue(ctx->builder, LLVMGetUndef(to_llvm_type(ctx, TYPE_ENUM)), tag, 0, "enum_tag");
+        return LLVMBuildInsertValue(ctx->builder, agg, payload, 1, "enum_val");
+    }
+
+    case AST_METHOD_CALL: {
+        const char* struct_name = node->as.method_call.receiver->type.struct_name;
+        char mangled[256];
+        // A method_name that already contains '$' is used verbatim (original note: trait method, sema stored the fully qualified name).
+        if (strchr(node->as.method_call.method_name, '$'))
+            snprintf(mangled, sizeof(mangled), "%s", node->as.method_call.method_name);
+        else // passing NULL to "%s" is undefined; "" just makes the lookup below fail
+            snprintf(mangled, sizeof(mangled), "%s$%s", struct_name ? struct_name : "",
+                node->as.method_call.method_name);
+        LLVMValueRef fn = find_fn(ctx, mangled);
+        if (!fn) return NULL;
+        // Argument list: [receiver, user arguments...]
+        if (node->as.method_call.arg_count + 1 > 16) { ctx->error = "方法参数过多(最多15)"; return NULL; }
+        LLVMValueRef args[16];
+        args[0] = codegen_expr(ctx, node->as.method_call.receiver);
+        if (!args[0]) return NULL;
+        for (size_t i = 0; i < node->as.method_call.arg_count; i++) {
+            args[i + 1] = codegen_expr(ctx, node->as.method_call.args[i]);
+            if (!args[i + 1]) return NULL;
+        }
+        LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
+        LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
+        return LLVMBuildCall2(ctx->builder, fn_ty, fn, args,
+            (unsigned)(node->as.method_call.arg_count + 1),
+            ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "method_call");
+    }
+
+    case AST_INDEX_EXPR: {
+        // Resolve the array's storage pointer; only a plain identifier is handled, any other array expression yields NULL.
+        AstNode* arr_node = node->as.index_expr.array;
+        LLVMValueRef arr_ptr = NULL;
+        LLVMTypeRef arr_gp_type = NULL;
+
+        if (arr_node->kind == AST_IDENT_EXPR) {
+            arr_ptr = find_var(ctx, arr_node->as.ident.name);
+            // Fetch the allocated (array) type from the variable table; GEP needs it.
+            for (VarEntry* e = ctx->var_table; e; e = e->next) {
+                if (strcmp(e->name, arr_node->as.ident.name) == 0) {
+                    arr_gp_type = e->alloca_type; break;
+                }
+            }
+        }
+        if (!arr_ptr || !arr_gp_type) return NULL;
+
+        // Evaluate the index expression.
+        LLVMValueRef idx_val = codegen_expr(ctx, node->as.index_expr.index);
+        if (!idx_val) return NULL;
+
+        // Array GEP indices may have any integer width, so normalize to i64 (sext / trunc / no-op) rather than truncating to i32.
+        LLVMValueRef idx = LLVMBuildIntCast2(ctx->builder, idx_val,
+            LLVMInt64TypeInContext(ctx->context), true, "idx64");
+
+        LLVMValueRef indices[] = {
+            LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
+            idx
+        };
+        LLVMValueRef elem_ptr = LLVMBuildGEP2(ctx->builder, arr_gp_type, arr_ptr, indices, 2, "arr_elem");
+
+        LLVMTypeRef elem_load_ty;
+        if (node->type.kind == TYPE_STRUCT && node->type.struct_name) {
+            elem_load_ty = find_struct_type(ctx, node->type.struct_name);
+            if (!elem_load_ty) elem_load_ty = to_llvm_type(ctx, node->type.kind);
+        } else {
+            elem_load_ty = type_info_to_llvm(ctx, &node->type);
+        }
+        return LLVMBuildLoad2(ctx->builder, elem_load_ty, elem_ptr, "arr_load");
+    }
+
+    // 块表达式: 
{ stmt*; expr } → 最后表达式的值 + case AST_BLOCK: { + LLVMValueRef result = NULL; + for (size_t i = 0; i < node->as.block.stmt_count; i++) { + AstNode* stmt = node->as.block.stmts[i]; + bool is_last = (i == node->as.block.stmt_count - 1); + if (is_last && stmt->kind == AST_EXPR_STMT && node->type.kind != TYPE_VOID) { + result = codegen_expr(ctx, stmt->as.expr_stmt.expr); + } else { + codegen_stmt(ctx, stmt); + } + } + return result; + } + + // if 表达式: if cond { a } else { b } + case AST_IF_STMT: { + if (node->type.kind == TYPE_VOID) { codegen_stmt(ctx, node); return NULL; } + LLVMValueRef cond_val = codegen_expr(ctx, node->as.if_stmt.cond); + if (!cond_val) return NULL; + LLVMTypeRef res_ty = type_info_to_llvm(ctx, &node->type); + LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, res_ty, "if_res"); + LLVMValueRef func = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx->builder)); + LLVMBasicBlockRef then_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "then"); + LLVMBasicBlockRef else_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "else"); + LLVMBasicBlockRef merge_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "if_merge"); + LLVMBuildCondBr(ctx->builder, cond_val, then_bb, else_bb); + + LLVMPositionBuilderAtEnd(ctx->builder, then_bb); + LLVMValueRef then_val = codegen_expr(ctx, node->as.if_stmt.then_block); + if (then_val) LLVMBuildStore(ctx->builder, then_val, alloca); + LLVMBuildBr(ctx->builder, merge_bb); + + LLVMPositionBuilderAtEnd(ctx->builder, else_bb); + LLVMValueRef else_val = codegen_expr(ctx, node->as.if_stmt.else_block); + if (else_val) LLVMBuildStore(ctx->builder, else_val, alloca); + LLVMBuildBr(ctx->builder, merge_bb); + + LLVMPositionBuilderAtEnd(ctx->builder, merge_bb); + return LLVMBuildLoad2(ctx->builder, res_ty, alloca, "if_val"); + } + + default: + return NULL; + } +} + diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c index 2343fee..ea5a849 100644 --- a/src/codegen/codegen.c +++ b/src/codegen/codegen.c @@ -1,104 
+1,15 @@ -#include "codegen.h" -#include -#include -#include -#include +#include "codegen_internal.h" -// === 递归深度限制 -static int codegen_depth = 0; -#define MAX_CODEGEN_DEPTH 1000 - -// === 内部状态 === -typedef struct VarEntry { - const char* name; - LLVMValueRef alloca; - LLVMTypeRef alloca_type; // 分配的类型(GEP 需要) - struct VarEntry* next; -} VarEntry; - -typedef struct FnEntry { - const char* name; - LLVMValueRef fn; - TypeKind ret; - TypeKind* params; - size_t pc; - struct FnEntry* next; -} FnEntry; - -// 结构体类型映射 -typedef struct StructTypeEntry { - const char* name; - LLVMTypeRef llvm_type; - size_t field_count; - struct StructTypeEntry* next; -} StructTypeEntry; - -typedef struct { - Arena* arena; // 代码生成阶段分配器 - LLVMContextRef context; // LLVM 19+ 需要显式 Context - LLVMModuleRef module; - LLVMBuilderRef builder; - VarEntry* var_table; - const char* error; - FnEntry* fn_table; - StructTypeEntry* struct_table; - // printf 运行时支持(内置 print 函数委托给 printf) - LLVMValueRef printf_fn; - LLVMTypeRef printf_ty; - // 字符串拼接运行时支持 - LLVMValueRef malloc_fn; - LLVMValueRef free_fn; // auto-free 需要 free() - LLVMValueRef strlen_fn; - LLVMValueRef memcpy_fn; - // 自动内存管理: 追踪需要 free 的 str alloca - LLVMValueRef* cleanup_list; - size_t cleanup_count; - size_t cleanup_cap; -} CgCtx; - -// === 类型映射(需要 Context)=== -static LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) { - switch (kind) { - case TYPE_I32: return LLVMInt32TypeInContext(ctx->context); - case TYPE_I64: return LLVMInt64TypeInContext(ctx->context); - case TYPE_U64: return LLVMInt64TypeInContext(ctx->context); - case TYPE_F64: return LLVMDoubleTypeInContext(ctx->context); - case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context); - case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context); - case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0); - case TYPE_STRUCT: - case TYPE_ENUM: { - // tagged union: { i64 tag, i64 payload } - LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context), - 
LLVMInt64TypeInContext(ctx->context) }; - return LLVMStructTypeInContext(ctx->context, fields, 2, false); - } - case TYPE_UNKNOWN: - case TYPE_ERROR: - default: return LLVMVoidTypeInContext(ctx->context); - } -} - -static LLVMValueRef to_llvm_const(LLVMTypeRef ty, AstNode* lit) { - switch (lit->as.literal.lit_type) { - case TYPE_I32: - case TYPE_I64: return LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, true); - case TYPE_U64: return LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, false); - case TYPE_CHAR: return LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, false); - case TYPE_F64: return LLVMConstReal(ty, lit->as.literal.f64_val); - case TYPE_BOOL: return LLVMConstInt(ty, lit->as.literal.bool_val ? 1 : 0, false); - default: return NULL; - } -} +int codegen_depth = 0; // === 变量表 === -static LLVMValueRef find_var(CgCtx* ctx, const char* name) { +LLVMValueRef find_var(CgCtx* ctx, const char* name) { for (VarEntry* e = ctx->var_table; e; e = e->next) if (strcmp(e->name, name) == 0) return e->alloca; return NULL; } -static void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) { +void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) { VarEntry* e = arena_alloc(ctx->arena, sizeof(*e)); if (!e) return; e->name = name; e->alloca = alloca; e->alloca_type = alloca_type; e->next = ctx->var_table; @@ -106,13 +17,13 @@ static void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeR } // === 函数表 === -static LLVMValueRef find_fn(CgCtx* ctx, const char* name) { +LLVMValueRef find_fn(CgCtx* ctx, const char* name) { for (FnEntry* e = ctx->fn_table; e; e = e->next) if (strcmp(e->name, name) == 0) return e->fn; return NULL; } -static void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn) { +void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn) { FnEntry* e = arena_alloc(ctx->arena, sizeof(*e)); if (!e) return; e->name = name; e->fn = fn; @@ 
-124,7 +35,7 @@ static void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn) { } // === 结构体类型表 === -static void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc) { +void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc) { StructTypeEntry* e = arena_alloc(ctx->arena, sizeof(*e)); if (!e) return; e->name = name; e->llvm_type = ty; e->field_count = fc; @@ -132,420 +43,15 @@ static void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t ctx->struct_table = e; } -static LLVMTypeRef find_struct_type(CgCtx* ctx, const char* name) { +LLVMTypeRef find_struct_type(CgCtx* ctx, const char* name) { for (StructTypeEntry* e = ctx->struct_table; e; e = e->next) if (strcmp(e->name, name) == 0) return e->llvm_type; return NULL; } // 将整数值强制转换到目标 LLVM 类型(sext/zext/trunc) -static LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val, - LLVMTypeRef from_ty, LLVMTypeRef to_ty) { - if (from_ty == to_ty) return val; - int from_w = LLVMGetIntTypeWidth(from_ty); - int to_w = LLVMGetIntTypeWidth(to_ty); - if (from_w < to_w) - return LLVMBuildSExt(ctx->builder, val, to_ty, "sext"); - else - return LLVMBuildTrunc(ctx->builder, val, to_ty, "trunc"); -} - -// 从 TypeInfo 生成 LLVM 类型(支持数组、结构体等复合类型) -static LLVMTypeRef type_info_to_llvm(CgCtx* ctx, const TypeInfo* ti) { - switch (ti->kind) { - case TYPE_ARRAY: { - TypeInfo elem = { .kind = ti->element_type, .struct_name = ti->element_struct_name }; - LLVMTypeRef elem_ty = type_info_to_llvm(ctx, &elem); - return LLVMArrayType(elem_ty, (unsigned)ti->array_size); - } - case TYPE_STRUCT: - if (ti->struct_name) { - LLVMTypeRef st = find_struct_type(ctx, ti->struct_name); - if (st) return st; - } - return LLVMVoidTypeInContext(ctx->context); - case TYPE_ENUM: { - LLVMTypeRef f[] = { LLVMInt64TypeInContext(ctx->context), - LLVMInt64TypeInContext(ctx->context) }; - return LLVMStructTypeInContext(ctx->context, f, 2, false); - } - default: - return to_llvm_type(ctx, ti->kind); - } -} - -// === 
向前声明 === -static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node); -static void codegen_stmt(CgCtx* ctx, AstNode* node); - -// === 表达式代码生成 === -static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) { - if (!node) return NULL; - - switch (node->kind) { - case AST_LITERAL_EXPR: - if (node->type.kind == TYPE_STR) { - return LLVMBuildGlobalStringPtr(ctx->builder, node->as.literal.str_val, "str"); - } - return to_llvm_const(to_llvm_type(ctx, node->type.kind), node); - - case AST_IDENT_EXPR: { - LLVMValueRef ptr = find_var(ctx, node->as.ident.name); - if (!ptr) return NULL; - LLVMTypeRef load_ty = type_info_to_llvm(ctx, &node->type); - return LLVMBuildLoad2(ctx->builder, load_ty, ptr, "load"); - } - - case AST_UNARY_EXPR: { - LLVMValueRef operand = codegen_expr(ctx, node->as.unary.operand); - if (!operand) return NULL; - if (node->as.unary.op == OP_NEG) { - if (node->type.kind == TYPE_F64) - return LLVMBuildFNeg(ctx->builder, operand, "fneg"); - else - return LLVMBuildNeg(ctx->builder, operand, "ineg"); - } else { - return LLVMBuildNot(ctx->builder, operand, "not"); - } - } - - case AST_BINARY_EXPR: { - LLVMValueRef l = codegen_expr(ctx, node->as.binary.left); - LLVMValueRef r = codegen_expr(ctx, node->as.binary.right); - if (!l || !r) return NULL; - - // 字符串拼接:alloc 栈缓冲区,strcpy + strcat - if (node->type.kind == TYPE_STR) { - // strlen(left) - LLVMValueRef len_l = LLVMBuildCall2(ctx->builder, - LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn, - (LLVMValueRef[]){l}, 1, "strlen_l"); - // strlen(right) - LLVMValueRef len_r = LLVMBuildCall2(ctx->builder, - LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn, - (LLVMValueRef[]){r}, 1, "strlen_r"); - // total = len_l + len_r + 1 - LLVMValueRef total = LLVMBuildAdd(ctx->builder, len_l, len_r, "total"); - total = LLVMBuildAdd(ctx->builder, total, - LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "total_1"); - // char* buf = malloc(total) - LLVMValueRef buf = LLVMBuildCall2(ctx->builder, - 
LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn, - (LLVMValueRef[]){total}, 1, "str_buf"); - // memcpy(buf, left, len_l) - LLVMBuildCall2(ctx->builder, - LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn, - (LLVMValueRef[]){buf, l, len_l}, 3, ""); - // memcpy(buf + len_l, right, len_r + 1) -- includes null terminator - LLVMValueRef offset_ptr = LLVMBuildGEP2(ctx->builder, - LLVMInt8TypeInContext(ctx->context), buf, - (LLVMValueRef[]){len_l}, 1, "offset"); - LLVMValueRef len_r1 = LLVMBuildAdd(ctx->builder, len_r, - LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "len_r1"); - LLVMBuildCall2(ctx->builder, - LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn, - (LLVMValueRef[]){offset_ptr, r, len_r1}, 3, ""); - return buf; - } - - bool is_float = (node->type.kind == TYPE_F64); - - switch (node->as.binary.op) { - case OP_ADD: - return is_float ? LLVMBuildFAdd(ctx->builder, l, r, "fadd") - : LLVMBuildAdd(ctx->builder, l, r, "iadd"); - case OP_SUB: - return is_float ? LLVMBuildFSub(ctx->builder, l, r, "fsub") - : LLVMBuildSub(ctx->builder, l, r, "isub"); - case OP_MUL: - return is_float ? LLVMBuildFMul(ctx->builder, l, r, "fmul") - : LLVMBuildMul(ctx->builder, l, r, "imul"); - case OP_DIV: - return is_float ? 
LLVMBuildFDiv(ctx->builder, l, r, "fdiv") - : LLVMBuildSDiv(ctx->builder, l, r, "sdiv"); - case OP_MOD: - return LLVMBuildSRem(ctx->builder, l, r, "srem"); - case OP_EQ: - case OP_NE: - case OP_LT: - case OP_GT: - case OP_LE: - case OP_GE: { - // 枚举比较: 提取 tag 字段再比较 - LLVMValueRef cl = l, cr = r; - if (node->as.binary.left->type.kind == TYPE_ENUM) { - cl = LLVMBuildExtractValue(ctx->builder, l, 0, "enum_tag_l"); - } - if (node->as.binary.right->type.kind == TYPE_ENUM) { - cr = LLVMBuildExtractValue(ctx->builder, r, 0, "enum_tag_r"); - } - LLVMIntPredicate pred; - switch (node->as.binary.op) { - case OP_EQ: pred = LLVMIntEQ; break; - case OP_NE: pred = LLVMIntNE; break; - case OP_LT: pred = LLVMIntSLT; break; - case OP_GT: pred = LLVMIntSGT; break; - case OP_LE: pred = LLVMIntSLE; break; - case OP_GE: pred = LLVMIntSGE; break; - default: return NULL; - } - if (is_float) - return LLVMBuildFCmp(ctx->builder, pred == LLVMIntEQ ? LLVMRealOEQ : - pred == LLVMIntNE ? LLVMRealONE : pred == LLVMIntSLT ? LLVMRealOLT : - pred == LLVMIntSGT ? LLVMRealOGT : pred == LLVMIntSLE ? 
LLVMRealOLE : - LLVMRealOGE, cl, cr, "fcmp"); - return LLVMBuildICmp(ctx->builder, pred, cl, cr, "icmp"); - } - case OP_AND: - return LLVMBuildAnd(ctx->builder, l, r, "and"); - case OP_OR: - return LLVMBuildOr(ctx->builder, l, r, "or"); - default: - return NULL; - } - } - - case AST_CALL_EXPR: { - // === 内置 print 函数:委托给 printf === - if (strcmp(node->as.call.name, "print_i64") == 0) { - LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); - if (!arg) return NULL; - // 枚举类型: 提取 tag 字段 - if (node->as.call.args[0]->type.kind == TYPE_ENUM) - arg = LLVMBuildExtractValue(ctx->builder, arg, 0, "tag"); - LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context); - arg = coerce_int(ctx, arg, LLVMTypeOf(arg), i64_ty); - LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%lld\n", "fmt_i64"); - LLVMValueRef printf_args[] = { fmt, arg }; - return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, - printf_args, 2, ""); - } - if (strcmp(node->as.call.name, "print_f64") == 0) { - LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); - if (!arg) return NULL; - LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%f\n", "fmt_f64"); - LLVMValueRef printf_args[] = { fmt, arg }; - return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, - printf_args, 2, ""); - } - if (strcmp(node->as.call.name, "print_bool") == 0) { - LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); - if (!arg) return NULL; - // 将 bool 转为字符串:通过 select 在 "true\n" 和 "false\n" 之间选择 - LLVMValueRef c = LLVMBuildICmp(ctx->builder, LLVMIntEQ, arg, - LLVMConstInt(LLVMInt1TypeInContext(ctx->context), 1, false), "bool_cmp"); - LLVMValueRef true_str = LLVMBuildGlobalStringPtr(ctx->builder, "true\n", "true_str"); - LLVMValueRef false_str = LLVMBuildGlobalStringPtr(ctx->builder, "false\n", "false_str"); - LLVMValueRef selected = LLVMBuildSelect(ctx->builder, c, true_str, false_str, "bool_sel"); - return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, - 
(LLVMValueRef[]){selected}, 1, ""); - } - if (strcmp(node->as.call.name, "print_str") == 0) { - LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]); - if (!arg) return NULL; - LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%s\n", "fmt_str"); - LLVMValueRef printf_args[] = { fmt, arg }; - return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn, - printf_args, 2, ""); - } - - // === 常规函数调用 === - LLVMValueRef fn = find_fn(ctx, node->as.call.name); - if (!fn) return NULL; - LLVMValueRef args[16]; - if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; } - for (size_t i = 0; i < node->as.call.arg_count; i++) { - args[i] = codegen_expr(ctx, node->as.call.args[i]); - if (!args[i]) return NULL; - } - LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn); - LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty); - return LLVMBuildCall2(ctx->builder, fn_ty, fn, - args, (unsigned)node->as.call.arg_count, - ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call"); - } - - // === 结构体字段访问: p.x === - case AST_FIELD_ACCESS: { - // 对对象求值(返回的是 struct 值) - LLVMValueRef struct_val = codegen_expr(ctx, node->as.field_access.object); - if (!struct_val) return NULL; - - int field_idx = node->as.field_access.field_index; - if (field_idx < 0) return NULL; // sema 应当已经设置 - - // 用 extractvalue 从结构体值中提取字段 - return LLVMBuildExtractValue(ctx->builder, struct_val, - (unsigned)field_idx, node->as.field_access.field); - } - - // === 结构体初始化: Point { x: 10, y: 20 } === - case AST_STRUCT_INIT: { - const char* st_name = node->as.struct_init.type_name; - LLVMTypeRef struct_ty = find_struct_type(ctx, st_name); - if (!struct_ty) return NULL; - - // alloca 分配结构体空间 - LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, struct_ty, "struct_init"); - - // 获取结构体字段名列表(从 struct_table 或从 AST 中) - // 对每个 init 字段,找到它在结构体中的索引并 store - for (size_t i = 0; i < node->as.struct_init.field_count; i++) { - AstNode* fval = node->as.struct_init.field_values[i]; - LLVMValueRef val = 
codegen_expr(ctx, fval); - if (!val) return NULL; - - // 获取字段指针: GEP struct_ty, alloca, 0, i - LLVMValueRef indices[] = { - LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false), - LLVMConstInt(LLVMInt32TypeInContext(ctx->context), (unsigned long long)i, false) - }; - LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, struct_ty, alloca, - indices, 2, "field_ptr"); - LLVMBuildStore(ctx->builder, val, field_ptr); - } - - // 加载整个结构体值 - return LLVMBuildLoad2(ctx->builder, struct_ty, alloca, "struct_val"); - } - - case AST_ENUM_VARIANT: { - // tagged union: { tag, payload } - LLVMValueRef tag = LLVMConstInt(LLVMInt64TypeInContext(ctx->context), - (unsigned long long)node->as.enum_variant.variant_index, true); - LLVMValueRef payload = LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 0, true); - if (node->as.enum_variant.payload) { - LLVMValueRef pv = codegen_expr(ctx, node->as.enum_variant.payload); - if (pv) { - // 将 payload 强制转换为 i64 - LLVMTypeRef pv_ty = LLVMTypeOf(pv); - LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context); - if (pv_ty != i64_ty && LLVMGetTypeKind(pv_ty) == LLVMIntegerTypeKind) - pv = coerce_int(ctx, pv, pv_ty, i64_ty); - payload = pv; - } - } - LLVMValueRef fields[] = { tag, payload }; - return LLVMConstStruct(fields, 2, false); - } - - case AST_METHOD_CALL: { - const char* struct_name = node->as.method_call.receiver->type.struct_name; - char mangled[256]; - // 若 method_name 已含 $(trait 方法,sema 已设置全限定名),直接用 - if (strchr(node->as.method_call.method_name, '$')) - snprintf(mangled, sizeof(mangled), "%s", node->as.method_call.method_name); - else - snprintf(mangled, sizeof(mangled), "%s$%s", struct_name, - node->as.method_call.method_name); - LLVMValueRef fn = find_fn(ctx, mangled); - if (!fn) return NULL; - // 参数列表: [receiver, 用户参数...] 
- if (node->as.method_call.arg_count + 1 > 16) { ctx->error = "方法参数过多(最多15)"; return NULL; } - LLVMValueRef args[16]; - args[0] = codegen_expr(ctx, node->as.method_call.receiver); - if (!args[0]) return NULL; - for (size_t i = 0; i < node->as.method_call.arg_count; i++) { - args[i + 1] = codegen_expr(ctx, node->as.method_call.args[i]); - if (!args[i + 1]) return NULL; - } - LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn); - LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty); - return LLVMBuildCall2(ctx->builder, fn_ty, fn, args, - (unsigned)(node->as.method_call.arg_count + 1), - ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "method_call"); - } - - case AST_INDEX_EXPR: { - // 获取数组变量的指针 - AstNode* arr_node = node->as.index_expr.array; - LLVMValueRef arr_ptr = NULL; - LLVMTypeRef arr_gp_type = NULL; - - if (arr_node->kind == AST_IDENT_EXPR) { - arr_ptr = find_var(ctx, arr_node->as.ident.name); - // 从变量表获取数组类型用于 GEP - for (VarEntry* e = ctx->var_table; e; e = e->next) { - if (strcmp(e->name, arr_node->as.ident.name) == 0) { - arr_gp_type = e->alloca_type; break; - } - } - } - if (!arr_ptr || !arr_gp_type) return NULL; - - // 生成索引值 - LLVMValueRef idx_val = codegen_expr(ctx, node->as.index_expr.index); - if (!idx_val) return NULL; - - // GEP 索引必须是 i32,但 L 使用 i64。截断。 - LLVMValueRef idx_i32 = LLVMBuildTrunc(ctx->builder, idx_val, - LLVMInt32TypeInContext(ctx->context), "idx32"); - - LLVMValueRef indices[] = { - LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false), - idx_i32 - }; - LLVMValueRef elem_ptr = LLVMBuildGEP2(ctx->builder, arr_gp_type, arr_ptr, indices, 2, "arr_elem"); - - LLVMTypeRef elem_load_ty; - if (node->type.kind == TYPE_STRUCT && node->type.struct_name) { - elem_load_ty = find_struct_type(ctx, node->type.struct_name); - if (!elem_load_ty) elem_load_ty = to_llvm_type(ctx, node->type.kind); - } else { - elem_load_ty = type_info_to_llvm(ctx, &node->type); - } - return LLVMBuildLoad2(ctx->builder, elem_load_ty, elem_ptr, "arr_load"); - } - - // 块表达式: 
{ stmt*; expr } → 最后表达式的值 - case AST_BLOCK: { - LLVMValueRef result = NULL; - for (size_t i = 0; i < node->as.block.stmt_count; i++) { - AstNode* stmt = node->as.block.stmts[i]; - bool is_last = (i == node->as.block.stmt_count - 1); - if (is_last && stmt->kind == AST_EXPR_STMT && node->type.kind != TYPE_VOID) { - result = codegen_expr(ctx, stmt->as.expr_stmt.expr); - } else { - codegen_stmt(ctx, stmt); - } - } - return result; - } - - // if 表达式: if cond { a } else { b } - case AST_IF_STMT: { - if (node->type.kind == TYPE_VOID) { codegen_stmt(ctx, node); return NULL; } - LLVMValueRef cond_val = codegen_expr(ctx, node->as.if_stmt.cond); - if (!cond_val) return NULL; - LLVMTypeRef res_ty = type_info_to_llvm(ctx, &node->type); - LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, res_ty, "if_res"); - LLVMValueRef func = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx->builder)); - LLVMBasicBlockRef then_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "then"); - LLVMBasicBlockRef else_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "else"); - LLVMBasicBlockRef merge_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "if_merge"); - LLVMBuildCondBr(ctx->builder, cond_val, then_bb, else_bb); - - LLVMPositionBuilderAtEnd(ctx->builder, then_bb); - LLVMValueRef then_val = codegen_expr(ctx, node->as.if_stmt.then_block); - if (then_val) LLVMBuildStore(ctx->builder, then_val, alloca); - LLVMBuildBr(ctx->builder, merge_bb); - - LLVMPositionBuilderAtEnd(ctx->builder, else_bb); - LLVMValueRef else_val = codegen_expr(ctx, node->as.if_stmt.else_block); - if (else_val) LLVMBuildStore(ctx->builder, else_val, alloca); - LLVMBuildBr(ctx->builder, merge_bb); - - LLVMPositionBuilderAtEnd(ctx->builder, merge_bb); - return LLVMBuildLoad2(ctx->builder, res_ty, alloca, "if_val"); - } - - default: - return NULL; - } -} - // === 自动内存管理: 作用域退出时释放 str 堆分配 === -static void cleanup_add(CgCtx* ctx, LLVMValueRef alloca) { +void cleanup_add(CgCtx* ctx, LLVMValueRef alloca) { if 
(ctx->cleanup_count >= ctx->cleanup_cap) { size_t new_cap = ctx->cleanup_cap ? ctx->cleanup_cap * 2 : 16; LLVMValueRef* new_list = arena_alloc(ctx->arena, new_cap * sizeof(LLVMValueRef)); @@ -559,7 +65,7 @@ static void cleanup_add(CgCtx* ctx, LLVMValueRef alloca) { } // 释放从 mark 位置开始的所有 str 变量 -static void cleanup_emit(CgCtx* ctx, size_t from_mark) { +void cleanup_emit(CgCtx* ctx, size_t from_mark) { for (size_t j = from_mark; j < ctx->cleanup_count; j++) { LLVMValueRef ptr = ctx->cleanup_list[j]; LLVMValueRef val = LLVMBuildLoad2(ctx->builder, @@ -572,7 +78,7 @@ static void cleanup_emit(CgCtx* ctx, size_t from_mark) { } // === 语句代码生成 === -static void codegen_stmt(CgCtx* ctx, AstNode* node) { +void codegen_stmt(CgCtx* ctx, AstNode* node) { if (!node) return; switch (node->kind) { diff --git a/src/codegen/codegen_internal.h b/src/codegen/codegen_internal.h new file mode 100644 index 0000000..f87aebd --- /dev/null +++ b/src/codegen/codegen_internal.h @@ -0,0 +1,83 @@ +#ifndef CODEGEN_INTERNAL_H +#define CODEGEN_INTERNAL_H + +#include "codegen.h" +#include "ast.h" +#include "arena.h" +#include "l_lang.h" +#include +#include +#include +#include + +// 递归深度限制 +extern int codegen_depth; +#define MAX_CODEGEN_DEPTH 1000 + +// === 内部状态 === +typedef struct VarEntry { + const char* name; + LLVMValueRef alloca; + LLVMTypeRef alloca_type; + struct VarEntry* next; +} VarEntry; + +typedef struct FnEntry { + const char* name; + LLVMValueRef fn; + TypeKind ret; + TypeKind* params; + size_t pc; + struct FnEntry* next; +} FnEntry; + +typedef struct StructTypeEntry { + const char* name; + LLVMTypeRef llvm_type; + size_t field_count; + struct StructTypeEntry* next; +} StructTypeEntry; + +typedef struct { + Arena* arena; + LLVMContextRef context; + LLVMModuleRef module; + LLVMBuilderRef builder; + VarEntry* var_table; + const char* error; + FnEntry* fn_table; + StructTypeEntry* struct_table; + LLVMValueRef printf_fn; + LLVMTypeRef printf_ty; + LLVMValueRef malloc_fn; + LLVMValueRef 
free_fn; + LLVMValueRef strlen_fn; + LLVMValueRef memcpy_fn; + LLVMValueRef* cleanup_list; + size_t cleanup_count; + size_t cleanup_cap; +} CgCtx; + +// === 类型映射 === +LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind); +LLVMValueRef to_llvm_const(LLVMTypeRef ty, AstNode* lit); +LLVMTypeRef type_info_to_llvm(CgCtx* ctx, const TypeInfo* ti); +LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val, LLVMTypeRef from_ty, LLVMTypeRef to_ty); + +// === 表操作 === +LLVMValueRef find_var(CgCtx* ctx, const char* name); +void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type); +LLVMValueRef find_fn(CgCtx* ctx, const char* name); +void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn); +void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc); +LLVMTypeRef find_struct_type(CgCtx* ctx, const char* name); + +// === 内存清理 === +void cleanup_add(CgCtx* ctx, LLVMValueRef alloca); +void cleanup_emit(CgCtx* ctx, size_t from_mark); + +// === 代码生成函数 === +LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node); +void codegen_stmt(CgCtx* ctx, AstNode* node); + +#endif diff --git a/src/sema/sema.c b/src/sema/sema.c index ee80610..c76377e 100644 --- a/src/sema/sema.c +++ b/src/sema/sema.c @@ -1,636 +1,6 @@ -#include "sema.h" -#include -#include +#include "sema_internal.h" -// === 泛型单态化: AST 类型替换 === -// 将 AST 中所有匹配 type_param_name 的类型引用替换为 concrete_type -static void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname); - -static void subst_type_info(TypeInfo* ti, const char* tparam, TypeKind concrete, const char* concrete_sname) { - if (ti->kind == TYPE_STRUCT && ti->struct_name && strcmp(ti->struct_name, tparam) == 0) { - ti->kind = concrete; - ti->struct_name = concrete_sname; - } -} - -static void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname) { - if (!node) return; - // 替换节点自身类型 - subst_type_info(&node->type, tparam, concrete, concrete_sname); 
- switch (node->kind) { - case AST_PROGRAM: - for (size_t i = 0; i < node->as.program.fn_count; i++) - subst_ast_types(node->as.program.functions[i], tparam, concrete, concrete_sname); - break; - case AST_FUNCTION: - if (node->as.function.return_type == TYPE_STRUCT - && node->as.function.return_struct_type_name - && strcmp(node->as.function.return_struct_type_name, tparam) == 0) { - node->as.function.return_type = concrete; - node->as.function.return_struct_type_name = concrete_sname; - } - for (size_t i = 0; i < node->as.function.param_count; i++) - subst_ast_types(node->as.function.params[i], tparam, concrete, concrete_sname); - subst_ast_types(node->as.function.body, tparam, concrete, concrete_sname); - break; - case AST_PARAMETER: - if (node->as.parameter.type == TYPE_STRUCT && node->as.parameter.struct_type_name - && strcmp(node->as.parameter.struct_type_name, tparam) == 0) { - node->as.parameter.type = concrete; - node->as.parameter.struct_type_name = concrete_sname; - } - break; - case AST_BLOCK: - for (size_t i = 0; i < node->as.block.stmt_count; i++) - subst_ast_types(node->as.block.stmts[i], tparam, concrete, concrete_sname); - break; - case AST_LET_STMT: - if (node->as.let_stmt.annot_type == TYPE_STRUCT && node->as.let_stmt.struct_type_name - && strcmp(node->as.let_stmt.struct_type_name, tparam) == 0) { - node->as.let_stmt.annot_type = concrete; - node->as.let_stmt.struct_type_name = concrete_sname; - } - subst_ast_types(node->as.let_stmt.init, tparam, concrete, concrete_sname); - break; - case AST_IF_STMT: - subst_ast_types(node->as.if_stmt.cond, tparam, concrete, concrete_sname); - subst_ast_types(node->as.if_stmt.then_block, tparam, concrete, concrete_sname); - subst_ast_types(node->as.if_stmt.else_block, tparam, concrete, concrete_sname); - break; - case AST_WHILE_STMT: - subst_ast_types(node->as.while_stmt.cond, tparam, concrete, concrete_sname); - subst_ast_types(node->as.while_stmt.body, tparam, concrete, concrete_sname); - break; - case 
AST_RETURN_STMT: - subst_ast_types(node->as.return_stmt.expr, tparam, concrete, concrete_sname); - break; - case AST_EXPR_STMT: - subst_ast_types(node->as.expr_stmt.expr, tparam, concrete, concrete_sname); - break; - case AST_BINARY_EXPR: - subst_ast_types(node->as.binary.left, tparam, concrete, concrete_sname); - subst_ast_types(node->as.binary.right, tparam, concrete, concrete_sname); - break; - case AST_UNARY_EXPR: - subst_ast_types(node->as.unary.operand, tparam, concrete, concrete_sname); - break; - case AST_CALL_EXPR: - for (size_t i = 0; i < node->as.call.arg_count; i++) - subst_ast_types(node->as.call.args[i], tparam, concrete, concrete_sname); - break; - case AST_ASSIGN_STMT: - subst_ast_types(node->as.assign_stmt.value, tparam, concrete, concrete_sname); - break; - case AST_FIELD_ACCESS: - subst_ast_types(node->as.field_access.object, tparam, concrete, concrete_sname); - break; - case AST_STRUCT_INIT: - for (size_t i = 0; i < node->as.struct_init.field_count; i++) - subst_ast_types(node->as.struct_init.field_values[i], tparam, concrete, concrete_sname); - break; - default: break; - } -} - -// === 类型关系 === -static TypeKind current_return_type = TYPE_VOID; -static const char* current_return_struct_name = NULL; - -// 单态化队列: 泛型函数调用时生成的具象化函数 -static AstNode* mono_queue[256]; -static size_t mono_count = 0; -static Arena* mono_arena = NULL; -static AstNode* g_program = NULL; // 当前 AST_PROGRAM(用于查找泛型函数模板) - -static TypeKind promote(TypeKind a, TypeKind b) { - // 枚举在算术运算中视为 i64 - if (a == TYPE_ENUM) a = TYPE_I64; - if (b == TYPE_ENUM) b = TYPE_I64; - // char 在算术中提升为 i32 - if (a == TYPE_CHAR) a = TYPE_I32; - if (b == TYPE_CHAR) b = TYPE_I32; - if (a == TYPE_F64 || b == TYPE_F64) return TYPE_F64; - if (a == TYPE_I64 || b == TYPE_I64) return TYPE_I64; - if (a == TYPE_U64 || b == TYPE_U64) return TYPE_U64; - if (a == TYPE_I32 || b == TYPE_I32) return TYPE_I32; - if (a == TYPE_BOOL || b == TYPE_BOOL) return TYPE_BOOL; - return TYPE_ERROR; -} - -static bool 
is_numeric(TypeKind t) { - return t == TYPE_I32 || t == TYPE_I64 || t == TYPE_U64 - || t == TYPE_F64 || t == TYPE_CHAR || t == TYPE_ENUM; -} -// 隐式类型转换规则: 无损加宽允许,有符号→无符号不允许 -static bool can_implicit_convert(TypeKind from, TypeKind to) { - if (from == to) return true; - // 枚举视为 i64 - if (from == TYPE_ENUM) from = TYPE_I64; - if (to == TYPE_ENUM) to = TYPE_I64; - // char 可转为任意整数 - if (from == TYPE_CHAR) return to == TYPE_I32 || to == TYPE_I64 || to == TYPE_U64 || to == TYPE_F64; - // i32 可加宽 - if (from == TYPE_I32) return to == TYPE_I64 || to == TYPE_F64; - // i64 可转 f64 - if (from == TYPE_I64) return to == TYPE_F64; - // u64 ↔ i64 双向允许(同一位宽,LLVM 同类型) - if (from == TYPE_U64) return to == TYPE_F64 || to == TYPE_I64; - if (from == TYPE_I64) return to == TYPE_F64 || to == TYPE_U64; - return false; -} -static bool is_comparable(TypeKind a, TypeKind b) { - if (a == b) return true; - // 枚举可以参与整数比较 - if ((a == TYPE_I64 && b == TYPE_ENUM) || (a == TYPE_ENUM && b == TYPE_I64)) return true; - return false; -} - -// === 向前声明 === -static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); - -// === 表达式类型检查辅助函数 === -static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); - -static void analyze_ident_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { - (void)a; - Symbol* sym = scope_lookup(scope, node->as.ident.name); - if (!sym) { - error_add(errors, "", node->loc.line, node->loc.col, - "未定义的变量 '%s'", node->as.ident.name); - node->type.kind = TYPE_ERROR; - } else if (sym->is_type_alias) { - error_add(errors, "", node->loc.line, node->loc.col, - "'%s' 是类型别名,不能作为表达式使用", node->as.ident.name); - node->type.kind = TYPE_ERROR; - } else if (sym->kind == SYM_FUNCTION) { - error_add(errors, "", node->loc.line, node->loc.col, - "'%s' 是函数,不能作为表达式使用", node->as.ident.name); - node->type.kind = TYPE_ERROR; - } else { - node->type.kind = sym->type; - if (sym->type == TYPE_STRUCT && sym->struct_type_name) - node->type.struct_name 
= sym->struct_type_name; - if (sym->type == TYPE_ARRAY) { - node->type.element_type = sym->array_element_type; - node->type.element_struct_name = sym->array_element_struct_name; - node->type.array_size = sym->array_size; - } - } -} - -static void analyze_unary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { - analyze_expr(node->as.unary.operand, scope, errors, a); - TypeKind inner = node->as.unary.operand->type.kind; - if (node->as.unary.op == OP_NEG) { - if (!is_numeric(inner)) { - error_add(errors, "", node->loc.line, node->loc.col, - "一元 '-' 只能用于数值类型"); - node->type.kind = TYPE_ERROR; - } else { - node->type.kind = inner; - } - } else { // OP_NOT - if (inner != TYPE_BOOL) { - error_add(errors, "", node->loc.line, node->loc.col, - "'!' 只能用于布尔类型,得到 '%s'", type_name(inner)); - node->type.kind = TYPE_ERROR; - } else { - node->type.kind = TYPE_BOOL; - } - } -} - -static void analyze_binary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { - analyze_expr(node->as.binary.left, scope, errors, a); - analyze_expr(node->as.binary.right, scope, errors, a); - TypeKind l = node->as.binary.left->type.kind; - TypeKind r = node->as.binary.right->type.kind; - if (l == TYPE_ERROR || r == TYPE_ERROR) { node->type.kind = TYPE_ERROR; return; } - - switch (node->as.binary.op) { - case OP_ADD: - if (l == TYPE_STR || r == TYPE_STR) { - if (l != TYPE_STR || r != TYPE_STR) { - error_add(errors, "", node->loc.line, node->loc.col, - "字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'", - type_name(l), type_name(r)); - node->type.kind = TYPE_ERROR; - } else { - node->type.kind = TYPE_STR; - } - } else if (!is_numeric(l) || !is_numeric(r)) { - error_add(errors, "", node->loc.line, node->loc.col, "算术运算需要数值类型"); - node->type.kind = TYPE_ERROR; - } else { - node->type.kind = promote(l, r); - } - break; - case OP_SUB: case OP_MUL: case OP_DIV: case OP_MOD: - if (!is_numeric(l) || !is_numeric(r)) { - error_add(errors, "", node->loc.line, node->loc.col, "算术运算需要数值类型"); - 
node->type.kind = TYPE_ERROR; - } else { - node->type.kind = promote(l, r); - } - break; - case OP_EQ: case OP_NE: case OP_LT: case OP_GT: case OP_LE: case OP_GE: - if (!is_comparable(l, r)) { - error_add(errors, "", node->loc.line, node->loc.col, - "类型 '%s' 和 '%s' 无法比较", type_name(l), type_name(r)); - node->type.kind = TYPE_ERROR; - } else { - node->type.kind = TYPE_BOOL; - } - break; - case OP_AND: case OP_OR: - if (l != TYPE_BOOL || r != TYPE_BOOL) { - error_add(errors, "", node->loc.line, node->loc.col, "逻辑运算需要布尔类型"); - node->type.kind = TYPE_ERROR; - } else { - node->type.kind = TYPE_BOOL; - } - break; - default: break; - } -} - -// 参数类型匹配检查(CALL_EXPR 和 METHOD_CALL 共用) -static bool check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname, - size_t idx, AstNode* call_node, Symbol* fn_sym, - ErrorList* errors, Arena* a) { - (void)a; - TypeKind actual = arg->type.kind; - if (actual == TYPE_ERROR) return false; - if (expected == TYPE_STRUCT && expected_sname) { - // 检查是否是泛型类型参数(匹配则接受任意类型) - if (fn_sym && fn_sym->type_params) { - for (size_t t = 0; t < fn_sym->type_param_count; t++) { - if (strcmp(expected_sname, fn_sym->type_params[t]) == 0) - return true; // 泛型参数,接受任意类型 - } - } - const char* actual_name = arg->type.struct_name; - if (actual != TYPE_STRUCT || !actual_name || - strcmp(actual_name, expected_sname) != 0) { - error_add(errors, "", call_node->loc.line, call_node->loc.col, - "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", - idx + 1, expected_sname ? expected_sname : "struct", - actual_name ? 
actual_name : type_name(actual)); - } - return false; - } - if (actual == expected) return false; - if (expected == TYPE_I64 && actual == TYPE_ENUM) return false; - if (can_implicit_convert(actual, expected)) return false; - if (actual == TYPE_I64 && arg->kind == AST_LITERAL_EXPR - && (expected == TYPE_I32 || expected == TYPE_U64 || expected == TYPE_CHAR)) return false; - error_add(errors, "", call_node->loc.line, call_node->loc.col, - "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", - idx + 1, type_name(expected), type_name(actual)); - return false; -} - -// 命名参数重排序(CALL_EXPR 和 METHOD_CALL 共用) -static bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset, - ErrorList* errors, const char* call_name) { - AstNode** args = node->as.call.args; - const char** arg_names = node->as.call.arg_names; - size_t arg_count = node->as.call.arg_count; - if (!arg_names) return true; - AstNode* reordered[16] = {0}; - for (size_t i = 0; i < arg_count; i++) { - if (arg_names[i]) { - bool found = false; - for (size_t j = param_offset; j < sym->param_count; j++) { - if (sym->param_names && sym->param_names[j] && - strcmp(arg_names[i], sym->param_names[j]) == 0) { - reordered[j - param_offset] = args[i]; - found = true; break; - } - } - if (!found) { - error_add(errors, "", node->loc.line, node->loc.col, - "'%s' 没有名为 '%s' 的参数", call_name, arg_names[i]); - return false; - } - } else { - reordered[i] = args[i]; - } - } - memcpy(args, reordered, arg_count * sizeof(AstNode*)); - return true; -} - -static void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { - Symbol* sym = scope_lookup(scope, node->as.call.name); - if (!sym || sym->kind != SYM_FUNCTION) { - error_add(errors, "", node->loc.line, node->loc.col, - "未定义的函数 '%s'", node->as.call.name); - node->type.kind = TYPE_ERROR; - for (size_t i = 0; i < node->as.call.arg_count; i++) - analyze_expr(node->as.call.args[i], scope, errors, a); - return; - } - if (node->as.call.arg_count != sym->param_count) { - 
error_add(errors, "", node->loc.line, node->loc.col, - "函数 '%s' 需要 %zu 个参数,但提供了 %zu 个", - node->as.call.name, sym->param_count, node->as.call.arg_count); - node->type.kind = TYPE_ERROR; - for (size_t i = 0; i < node->as.call.arg_count; i++) - analyze_expr(node->as.call.args[i], scope, errors, a); - return; - } - if (!reorder_named_args(node, sym, 0, errors, node->as.call.name)) { - node->type.kind = TYPE_ERROR; return; - } - for (size_t i = 0; i < node->as.call.arg_count; i++) { - analyze_expr(node->as.call.args[i], scope, errors, a); - bool is_generic_param = check_arg_type(node->as.call.args[i], sym->param_types[i], - sym->param_struct_names ? sym->param_struct_names[i] : NULL, - i, node, sym, errors, a); - // 泛型单态化: 创建具象化函数副本并注册 - if (is_generic_param && sym->type_params && sym->type_param_count > 0) { - TypeKind concrete = node->as.call.args[i]->type.kind; - const char* concrete_sn = node->as.call.args[i]->type.struct_name; - // 构造 mangled 名: fn$concrete_type - const char* ct_name = concrete_sn ? 
concrete_sn : type_name(concrete); - int mname_len = snprintf(NULL, 0, "%s$%s", node->as.call.name, ct_name) + 1; - char* mname = arena_alloc_impl(a, mname_len); - snprintf(mname, mname_len, "%s$%s", node->as.call.name, ct_name); - // 检查是否已存在 - Symbol* existing = scope_lookup(scope, mname); - if (!existing && g_program) { - // 查找原始泛型函数 AST 节点 - AstNode* generic_fn = NULL; - for (size_t fn_i = 0; fn_i < g_program->as.program.fn_count; fn_i++) { - if (strcmp(g_program->as.program.functions[fn_i]->as.function.name, - node->as.call.name) == 0) { - generic_fn = g_program->as.program.functions[fn_i]; - break; - } - } - if (generic_fn && mono_count < 256) { - // 创建浅拷贝(共享 body,subst_ast_types 修改类型标注) - AstNode* mono_fn = ast_make_function(a, mname, - generic_fn->as.function.params, - generic_fn->as.function.param_count, - generic_fn->as.function.return_type, - generic_fn->as.function.return_struct_type_name, - generic_fn->as.function.body, - false, NULL, 0, - generic_fn->loc); - // 类型替换: T → concrete - subst_ast_types(mono_fn, sym->type_params[0], concrete, concrete_sn); - // 注册到队列 - mono_queue[mono_count++] = mono_fn; - // 注册符号(后续分析会处理函数体) - TypeKind* mpts = mono_fn->as.function.param_count > 0 - ? 
arena_alloc_impl(a, mono_fn->as.function.param_count * sizeof(TypeKind)) : NULL; - for (size_t pj = 0; pj < mono_fn->as.function.param_count; pj++) { - mpts[pj] = mono_fn->as.function.params[pj]->as.parameter.type; - } - scope_insert_function(scope, a, mname, - mono_fn->as.function.return_type, - mono_fn->as.function.return_struct_type_name, - mpts, NULL, NULL, - mono_fn->as.function.param_count, NULL, 0); - } - } - // 重定向调用到单态化函数 - node->as.call.name = mname; - sym = scope_lookup(scope, mname); - if (!sym) { node->type.kind = TYPE_ERROR; return; } - } - } - node->type.kind = sym->return_type; - if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) - node->type.struct_name = sym->return_struct_type_name; -} - -static void analyze_field_access(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { - analyze_expr(node->as.field_access.object, scope, errors, a); - AstNode* obj = node->as.field_access.object; - if (obj->type.kind == TYPE_ERROR) { node->type.kind = TYPE_ERROR; return; } - if (obj->type.kind != TYPE_STRUCT) { - error_add(errors, "", node->loc.line, node->loc.col, - "类型 '%s' 不是结构体,不能访问字段 '%s'", - type_name(obj->type.kind), node->as.field_access.field); - node->type.kind = TYPE_ERROR; return; - } - const char* struct_name = obj->type.struct_name; - if (!struct_name) { - error_add(errors, "", node->loc.line, node->loc.col, "无法确定结构体类型"); - node->type.kind = TYPE_ERROR; return; - } - Symbol* struct_sym = scope_lookup_struct(scope, struct_name); - if (!struct_sym) { - error_add(errors, "", node->loc.line, node->loc.col, - "未定义的结构体 '%s'", struct_name); - node->type.kind = TYPE_ERROR; return; - } - int fi = scope_struct_field_index(struct_sym, node->as.field_access.field); - if (fi < 0) { - error_add(errors, "", node->loc.line, node->loc.col, - "结构体 '%s' 没有字段 '%s'", struct_name, node->as.field_access.field); - node->type.kind = TYPE_ERROR; return; - } - node->type.kind = struct_sym->struct_field_types[fi]; - node->as.field_access.field_index 
= fi; - if (node->type.kind == TYPE_STRUCT && struct_sym->struct_field_struct_names && - struct_sym->struct_field_struct_names[fi]) - node->type.struct_name = struct_sym->struct_field_struct_names[fi]; -} - -static void analyze_struct_init(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { - const char* resolved = node->as.struct_init.type_name; - Symbol* struct_sym = scope_lookup_struct(scope, resolved); - if (!struct_sym) { - Symbol* alias_sym = scope_lookup(scope, resolved); - if (alias_sym && alias_sym->is_type_alias && alias_sym->struct_type_name) { - resolved = alias_sym->struct_type_name; - struct_sym = scope_lookup_struct(scope, resolved); - node->as.struct_init.type_name = resolved; - } - } - if (!struct_sym) { - error_add(errors, "", node->loc.line, node->loc.col, - "未定义的结构体类型 '%s'", node->as.struct_init.type_name); - node->type.kind = TYPE_ERROR; return; - } - if (node->as.struct_init.field_count != struct_sym->struct_field_count) { - error_add(errors, "", node->loc.line, node->loc.col, - "结构体 '%s' 有 %zu 个字段,但提供了 %zu 个", - node->as.struct_init.type_name, - struct_sym->struct_field_count, node->as.struct_init.field_count); - node->type.kind = TYPE_ERROR; return; - } - for (size_t i = 0; i < node->as.struct_init.field_count; i++) { - const char* fname = node->as.struct_init.field_names[i]; - AstNode* fval = node->as.struct_init.field_values[i]; - analyze_expr(fval, scope, errors, a); - int fi = scope_struct_field_index(struct_sym, fname); - if (fi < 0) { - error_add(errors, "", node->loc.line, node->loc.col, - "结构体 '%s' 没有字段 '%s'", node->as.struct_init.type_name, fname); - node->type.kind = TYPE_ERROR; continue; - } - TypeKind expected = struct_sym->struct_field_types[fi]; - TypeKind actual = fval->type.kind; - if (actual != TYPE_ERROR && actual != expected) - error_add(errors, "", node->loc.line, node->loc.col, - "字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'", - fname, type_name(expected), type_name(actual)); - } - if (node->type.kind != TYPE_ERROR) { - 
node->type.kind = TYPE_STRUCT; - node->type.struct_name = resolved; - } -} - -static void analyze_enum_variant(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { - (void)a; - Symbol* enum_sym = scope_lookup_struct(scope, node->as.enum_variant.enum_name); - if (!enum_sym || enum_sym->kind != SYM_ENUM) { - error_add(errors, "", node->loc.line, node->loc.col, - "未定义的枚举 '%s'", node->as.enum_variant.enum_name); - node->type.kind = TYPE_ERROR; return; - } - int vi = scope_enum_variant_index(enum_sym, node->as.enum_variant.variant_name); - if (vi < 0) { - error_add(errors, "", node->loc.line, node->loc.col, - "枚举 '%s' 没有变体 '%s'", - node->as.enum_variant.enum_name, node->as.enum_variant.variant_name); - node->type.kind = TYPE_ERROR; return; - } - node->as.enum_variant.variant_index = vi; - // ADT: 检查 payload - TypeKind expected_pt = TYPE_VOID; - if (enum_sym->variant_payload_types) - expected_pt = enum_sym->variant_payload_types[vi]; - if (node->as.enum_variant.payload) { - if (expected_pt == TYPE_VOID && enum_sym->variant_payload_types) { - error_add(errors, "", node->loc.line, node->loc.col, - "枚举变体 '%s::%s' 不接受 payload", - node->as.enum_variant.enum_name, node->as.enum_variant.variant_name); - node->type.kind = TYPE_ERROR; return; - } - analyze_expr(node->as.enum_variant.payload, scope, errors, a); - TypeKind actual = node->as.enum_variant.payload->type.kind; - if (actual != TYPE_ERROR && actual != expected_pt) { - error_add(errors, "", node->loc.line, node->loc.col, - "枚举变体 payload 类型不匹配: 期望 '%s',得到 '%s'", - type_name(expected_pt), type_name(actual)); - } - } - node->type.kind = TYPE_ENUM; -} - -static void analyze_index_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { - analyze_expr(node->as.index_expr.array, scope, errors, a); - analyze_expr(node->as.index_expr.index, scope, errors, a); - AstNode* arr = node->as.index_expr.array; - AstNode* idx = node->as.index_expr.index; - if (arr->type.kind == TYPE_ERROR) { node->type.kind = TYPE_ERROR; 
return; } - if (arr->type.kind != TYPE_ARRAY) { - error_add(errors, "", node->loc.line, node->loc.col, - "类型 '%s' 不支持索引操作", type_name(arr->type.kind)); - node->type.kind = TYPE_ERROR; return; - } - if (idx->type.kind == TYPE_ERROR) { node->type.kind = TYPE_ERROR; return; } - if (idx->type.kind != TYPE_I64) { - error_add(errors, "", node->loc.line, node->loc.col, - "数组索引必须是 i64 类型, 得到 '%s'", type_name(idx->type.kind)); - node->type.kind = TYPE_ERROR; return; - } - node->type.kind = arr->type.element_type; - node->type.struct_name = arr->type.element_struct_name; -} - -static void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { - analyze_expr(node->as.method_call.receiver, scope, errors, a); - const char* recv_struct = node->as.method_call.receiver->type.struct_name; - if (node->as.method_call.receiver->type.kind != TYPE_STRUCT || !recv_struct) { - error_add(errors, "", node->loc.line, node->loc.col, - "只有结构体类型支持方法调用"); - node->type.kind = TYPE_ERROR; return; - } - char mangled[256]; - snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct, - node->as.method_call.method_name); - Symbol* sym = scope_lookup(scope, mangled); - // trait 方法 fallback: 搜索所有作用域中以 $method_name 结尾且以 StructName 开头的符号 - if (!sym || sym->kind != SYM_FUNCTION) { - char suffix[256]; - snprintf(suffix, sizeof(suffix), "$%s", node->as.method_call.method_name); - size_t suf_len = strlen(suffix); - size_t recv_len = strlen(recv_struct); - for (const Scope* sc = scope; sc; sc = sc->parent) { - for (Symbol* s = sc->head; s; s = s->next) { - if (s->kind == SYM_FUNCTION) { - size_t name_len = strlen(s->name); - if (name_len > suf_len + recv_len - && strncmp(s->name, recv_struct, recv_len) == 0 - && strcmp(s->name + name_len - suf_len, suffix) == 0) { - sym = s; - break; - } - } - } - if (sym) break; - } - } - // 更新 method_name 为符号的实际名称(codegen 需要通过它找到 LLVM 函数) - if (sym && sym->kind == SYM_FUNCTION) { - node->as.method_call.method_name = sym->name; - } - if (!sym || 
sym->kind != SYM_FUNCTION) { - error_add(errors, "", node->loc.line, node->loc.col, - "结构体 '%s' 没有方法 '%s'", recv_struct, - node->as.method_call.method_name); - node->type.kind = TYPE_ERROR; return; - } - if (node->as.method_call.arg_count + 1 != sym->param_count) { - error_add(errors, "", node->loc.line, node->loc.col, - "方法 '%s' 需要 %zu 个参数,提供了 %zu 个", - node->as.method_call.method_name, - sym->param_count > 0 ? sym->param_count - 1 : 0, - node->as.method_call.arg_count); - node->type.kind = TYPE_ERROR; return; - } - if (!reorder_named_args(node, sym, 1, errors, node->as.method_call.method_name)) { - node->type.kind = TYPE_ERROR; return; - } - for (size_t i = 0; i < node->as.method_call.arg_count; i++) { - analyze_expr(node->as.method_call.args[i], scope, errors, a); - check_arg_type(node->as.method_call.args[i], sym->param_types[i + 1], - sym->param_struct_names ? sym->param_struct_names[i + 1] : NULL, - i, node, sym, errors, a); - } - node->type.kind = sym->return_type; - if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) - node->type.struct_name = sym->return_struct_type_name; -} - -// === 表达式类型检查(调度器) === -static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { - switch (node->kind) { - case AST_LITERAL_EXPR: break; - case AST_IDENT_EXPR: analyze_ident_expr(node, scope, errors, a); break; - case AST_UNARY_EXPR: analyze_unary_expr(node, scope, errors, a); break; - case AST_BINARY_EXPR: analyze_binary_expr(node, scope, errors, a); break; - case AST_CALL_EXPR: analyze_call_expr(node, scope, errors, a); break; - case AST_FIELD_ACCESS: analyze_field_access(node, scope, errors, a); break; - case AST_STRUCT_INIT: analyze_struct_init(node, scope, errors, a); break; - case AST_ENUM_VARIANT: analyze_enum_variant(node, scope, errors, a); break; - case AST_INDEX_EXPR: analyze_index_expr(node, scope, errors, a); break; - case AST_METHOD_CALL: analyze_method_call(node, scope, errors, a); break; - case AST_IF_STMT: 
analyze_node(node, scope, errors, a); break; - case AST_BLOCK: analyze_node(node, scope, errors, a); break; - default: break; - } -} - -static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { +void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { if (!node) return; switch (node->kind) { diff --git a/src/sema/sema_internal.h b/src/sema/sema_internal.h new file mode 100644 index 0000000..31f9c6e --- /dev/null +++ b/src/sema/sema_internal.h @@ -0,0 +1,51 @@ +#ifndef SEMA_INTERNAL_H +#define SEMA_INTERNAL_H + +#include "sema.h" +#include "symbol.h" +#include "ast.h" +#include "error.h" +#include "arena.h" +#include "l_lang.h" +#include +#include + +// === 泛型单态化队列 === +extern AstNode* mono_queue[256]; +extern size_t mono_count; +extern Arena* mono_arena; +extern AstNode* g_program; + +// === 类型推断上下文 === +extern TypeKind current_return_type; +extern const char* current_return_struct_name; + +// === 类型关系 === +TypeKind promote(TypeKind a, TypeKind b); +bool is_numeric(TypeKind t); +bool can_implicit_convert(TypeKind from, TypeKind to); +bool is_comparable(TypeKind a, TypeKind b); + +// === 泛型单态化 === +void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname); +void subst_type_info(TypeInfo* ti, const char* tparam, TypeKind concrete, const char* concrete_sname); + +// === 表达式类型检查 === +void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); +void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); +void analyze_ident_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); +void analyze_unary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); +void analyze_binary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); +void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); +void analyze_field_access(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); +void 
analyze_struct_init(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); +void analyze_enum_variant(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); +void analyze_index_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); +void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); +bool check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname, + size_t idx, AstNode* call_node, Symbol* fn_sym, + ErrorList* errors, Arena* a); +bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset, + ErrorList* errors, const char* call_name); + +#endif diff --git a/src/sema/typeck.c b/src/sema/typeck.c new file mode 100644 index 0000000..cb7d339 --- /dev/null +++ b/src/sema/typeck.c @@ -0,0 +1,629 @@ +#include "sema_internal.h" + +// === 泛型单态化: AST 类型替换 === +// 将 AST 中所有匹配 type_param_name 的类型引用替换为 concrete_type +void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname); + +void subst_type_info(TypeInfo* ti, const char* tparam, TypeKind concrete, const char* concrete_sname) { + if (ti->kind == TYPE_STRUCT && ti->struct_name && strcmp(ti->struct_name, tparam) == 0) { + ti->kind = concrete; + ti->struct_name = concrete_sname; + } +} + +void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname) { + if (!node) return; + // 替换节点自身类型 + subst_type_info(&node->type, tparam, concrete, concrete_sname); + switch (node->kind) { + case AST_PROGRAM: + for (size_t i = 0; i < node->as.program.fn_count; i++) + subst_ast_types(node->as.program.functions[i], tparam, concrete, concrete_sname); + break; + case AST_FUNCTION: + if (node->as.function.return_type == TYPE_STRUCT + && node->as.function.return_struct_type_name + && strcmp(node->as.function.return_struct_type_name, tparam) == 0) { + node->as.function.return_type = concrete; + node->as.function.return_struct_type_name = concrete_sname; + } + for (size_t i = 0; i < 
node->as.function.param_count; i++) + subst_ast_types(node->as.function.params[i], tparam, concrete, concrete_sname); + subst_ast_types(node->as.function.body, tparam, concrete, concrete_sname); + break; + case AST_PARAMETER: + if (node->as.parameter.type == TYPE_STRUCT && node->as.parameter.struct_type_name + && strcmp(node->as.parameter.struct_type_name, tparam) == 0) { + node->as.parameter.type = concrete; + node->as.parameter.struct_type_name = concrete_sname; + } + break; + case AST_BLOCK: + for (size_t i = 0; i < node->as.block.stmt_count; i++) + subst_ast_types(node->as.block.stmts[i], tparam, concrete, concrete_sname); + break; + case AST_LET_STMT: + if (node->as.let_stmt.annot_type == TYPE_STRUCT && node->as.let_stmt.struct_type_name + && strcmp(node->as.let_stmt.struct_type_name, tparam) == 0) { + node->as.let_stmt.annot_type = concrete; + node->as.let_stmt.struct_type_name = concrete_sname; + } + subst_ast_types(node->as.let_stmt.init, tparam, concrete, concrete_sname); + break; + case AST_IF_STMT: + subst_ast_types(node->as.if_stmt.cond, tparam, concrete, concrete_sname); + subst_ast_types(node->as.if_stmt.then_block, tparam, concrete, concrete_sname); + subst_ast_types(node->as.if_stmt.else_block, tparam, concrete, concrete_sname); + break; + case AST_WHILE_STMT: + subst_ast_types(node->as.while_stmt.cond, tparam, concrete, concrete_sname); + subst_ast_types(node->as.while_stmt.body, tparam, concrete, concrete_sname); + break; + case AST_RETURN_STMT: + subst_ast_types(node->as.return_stmt.expr, tparam, concrete, concrete_sname); + break; + case AST_EXPR_STMT: + subst_ast_types(node->as.expr_stmt.expr, tparam, concrete, concrete_sname); + break; + case AST_BINARY_EXPR: + subst_ast_types(node->as.binary.left, tparam, concrete, concrete_sname); + subst_ast_types(node->as.binary.right, tparam, concrete, concrete_sname); + break; + case AST_UNARY_EXPR: + subst_ast_types(node->as.unary.operand, tparam, concrete, concrete_sname); + break; + case 
AST_CALL_EXPR: + for (size_t i = 0; i < node->as.call.arg_count; i++) + subst_ast_types(node->as.call.args[i], tparam, concrete, concrete_sname); + break; + case AST_ASSIGN_STMT: + subst_ast_types(node->as.assign_stmt.value, tparam, concrete, concrete_sname); + break; + case AST_FIELD_ACCESS: + subst_ast_types(node->as.field_access.object, tparam, concrete, concrete_sname); + break; + case AST_STRUCT_INIT: + for (size_t i = 0; i < node->as.struct_init.field_count; i++) + subst_ast_types(node->as.struct_init.field_values[i], tparam, concrete, concrete_sname); + break; + default: break; + } +} + +// === 类型关系 === +TypeKind current_return_type = TYPE_VOID; +const char* current_return_struct_name = NULL; + +// 单态化队列: 泛型函数调用时生成的具象化函数 +AstNode* mono_queue[256]; +size_t mono_count = 0; +Arena* mono_arena = NULL; +AstNode* g_program = NULL; // 当前 AST_PROGRAM(用于查找泛型函数模板) + +TypeKind promote(TypeKind a, TypeKind b) { + // 枚举在算术运算中视为 i64 + if (a == TYPE_ENUM) a = TYPE_I64; + if (b == TYPE_ENUM) b = TYPE_I64; + // char 在算术中提升为 i32 + if (a == TYPE_CHAR) a = TYPE_I32; + if (b == TYPE_CHAR) b = TYPE_I32; + if (a == TYPE_F64 || b == TYPE_F64) return TYPE_F64; + if (a == TYPE_I64 || b == TYPE_I64) return TYPE_I64; + if (a == TYPE_U64 || b == TYPE_U64) return TYPE_U64; + if (a == TYPE_I32 || b == TYPE_I32) return TYPE_I32; + if (a == TYPE_BOOL || b == TYPE_BOOL) return TYPE_BOOL; + return TYPE_ERROR; +} + +bool is_numeric(TypeKind t) { + return t == TYPE_I32 || t == TYPE_I64 || t == TYPE_U64 + || t == TYPE_F64 || t == TYPE_CHAR || t == TYPE_ENUM; +} +// 隐式类型转换规则: 无损加宽允许,有符号→无符号不允许 +bool can_implicit_convert(TypeKind from, TypeKind to) { + if (from == to) return true; + // 枚举视为 i64 + if (from == TYPE_ENUM) from = TYPE_I64; + if (to == TYPE_ENUM) to = TYPE_I64; + // char 可转为任意整数 + if (from == TYPE_CHAR) return to == TYPE_I32 || to == TYPE_I64 || to == TYPE_U64 || to == TYPE_F64; + // i32 可加宽 + if (from == TYPE_I32) return to == TYPE_I64 || to == TYPE_F64; + // i64 可转 f64 + if (from == 
TYPE_I64) return to == TYPE_F64; + // u64 ↔ i64 双向允许(同一位宽,LLVM 同类型) + if (from == TYPE_U64) return to == TYPE_F64 || to == TYPE_I64; + if (from == TYPE_I64) return to == TYPE_F64 || to == TYPE_U64; + return false; +} +bool is_comparable(TypeKind a, TypeKind b) { + if (a == b) return true; + // 枚举可以参与整数比较 + if ((a == TYPE_I64 && b == TYPE_ENUM) || (a == TYPE_ENUM && b == TYPE_I64)) return true; + return false; +} + +// === 向前声明 === +void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); + +// === 表达式类型检查辅助函数 === +void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); + +void analyze_ident_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { + (void)a; + Symbol* sym = scope_lookup(scope, node->as.ident.name); + if (!sym) { + error_add(errors, "", node->loc.line, node->loc.col, + "未定义的变量 '%s'", node->as.ident.name); + node->type.kind = TYPE_ERROR; + } else if (sym->is_type_alias) { + error_add(errors, "", node->loc.line, node->loc.col, + "'%s' 是类型别名,不能作为表达式使用", node->as.ident.name); + node->type.kind = TYPE_ERROR; + } else if (sym->kind == SYM_FUNCTION) { + error_add(errors, "", node->loc.line, node->loc.col, + "'%s' 是函数,不能作为表达式使用", node->as.ident.name); + node->type.kind = TYPE_ERROR; + } else { + node->type.kind = sym->type; + if (sym->type == TYPE_STRUCT && sym->struct_type_name) + node->type.struct_name = sym->struct_type_name; + if (sym->type == TYPE_ARRAY) { + node->type.element_type = sym->array_element_type; + node->type.element_struct_name = sym->array_element_struct_name; + node->type.array_size = sym->array_size; + } + } +} + +void analyze_unary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { + analyze_expr(node->as.unary.operand, scope, errors, a); + TypeKind inner = node->as.unary.operand->type.kind; + if (node->as.unary.op == OP_NEG) { + if (!is_numeric(inner)) { + error_add(errors, "", node->loc.line, node->loc.col, + "一元 '-' 只能用于数值类型"); + node->type.kind = TYPE_ERROR; + } else { + 
node->type.kind = inner; + } + } else { // OP_NOT + if (inner != TYPE_BOOL) { + error_add(errors, "", node->loc.line, node->loc.col, + "'!' 只能用于布尔类型,得到 '%s'", type_name(inner)); + node->type.kind = TYPE_ERROR; + } else { + node->type.kind = TYPE_BOOL; + } + } +} + +void analyze_binary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { + analyze_expr(node->as.binary.left, scope, errors, a); + analyze_expr(node->as.binary.right, scope, errors, a); + TypeKind l = node->as.binary.left->type.kind; + TypeKind r = node->as.binary.right->type.kind; + if (l == TYPE_ERROR || r == TYPE_ERROR) { node->type.kind = TYPE_ERROR; return; } + + switch (node->as.binary.op) { + case OP_ADD: + if (l == TYPE_STR || r == TYPE_STR) { + if (l != TYPE_STR || r != TYPE_STR) { + error_add(errors, "", node->loc.line, node->loc.col, + "字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'", + type_name(l), type_name(r)); + node->type.kind = TYPE_ERROR; + } else { + node->type.kind = TYPE_STR; + } + } else if (!is_numeric(l) || !is_numeric(r)) { + error_add(errors, "", node->loc.line, node->loc.col, "算术运算需要数值类型"); + node->type.kind = TYPE_ERROR; + } else { + node->type.kind = promote(l, r); + } + break; + case OP_SUB: case OP_MUL: case OP_DIV: case OP_MOD: + if (!is_numeric(l) || !is_numeric(r)) { + error_add(errors, "", node->loc.line, node->loc.col, "算术运算需要数值类型"); + node->type.kind = TYPE_ERROR; + } else { + node->type.kind = promote(l, r); + } + break; + case OP_EQ: case OP_NE: case OP_LT: case OP_GT: case OP_LE: case OP_GE: + if (!is_comparable(l, r)) { + error_add(errors, "", node->loc.line, node->loc.col, + "类型 '%s' 和 '%s' 无法比较", type_name(l), type_name(r)); + node->type.kind = TYPE_ERROR; + } else { + node->type.kind = TYPE_BOOL; + } + break; + case OP_AND: case OP_OR: + if (l != TYPE_BOOL || r != TYPE_BOOL) { + error_add(errors, "", node->loc.line, node->loc.col, "逻辑运算需要布尔类型"); + node->type.kind = TYPE_ERROR; + } else { + node->type.kind = TYPE_BOOL; + } + break; + default: break; + } +} + +// 
参数类型匹配检查(CALL_EXPR 和 METHOD_CALL 共用) +bool check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname, + size_t idx, AstNode* call_node, Symbol* fn_sym, + ErrorList* errors, Arena* a) { + (void)a; + TypeKind actual = arg->type.kind; + if (actual == TYPE_ERROR) return false; + if (expected == TYPE_STRUCT && expected_sname) { + // 检查是否是泛型类型参数(匹配则接受任意类型) + if (fn_sym && fn_sym->type_params) { + for (size_t t = 0; t < fn_sym->type_param_count; t++) { + if (strcmp(expected_sname, fn_sym->type_params[t]) == 0) + return true; // 泛型参数,接受任意类型 + } + } + const char* actual_name = arg->type.struct_name; + if (actual != TYPE_STRUCT || !actual_name || + strcmp(actual_name, expected_sname) != 0) { + error_add(errors, "", call_node->loc.line, call_node->loc.col, + "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", + idx + 1, expected_sname ? expected_sname : "struct", + actual_name ? actual_name : type_name(actual)); + } + return false; + } + if (actual == expected) return false; + if (expected == TYPE_I64 && actual == TYPE_ENUM) return false; + if (can_implicit_convert(actual, expected)) return false; + if (actual == TYPE_I64 && arg->kind == AST_LITERAL_EXPR + && (expected == TYPE_I32 || expected == TYPE_U64 || expected == TYPE_CHAR)) return false; + error_add(errors, "", call_node->loc.line, call_node->loc.col, + "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", + idx + 1, type_name(expected), type_name(actual)); + return false; +} + +// 命名参数重排序(CALL_EXPR 和 METHOD_CALL 共用) +bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset, + ErrorList* errors, const char* call_name) { + AstNode** args = node->as.call.args; + const char** arg_names = node->as.call.arg_names; + size_t arg_count = node->as.call.arg_count; + if (!arg_names) return true; + AstNode* reordered[16] = {0}; + for (size_t i = 0; i < arg_count; i++) { + if (arg_names[i]) { + bool found = false; + for (size_t j = param_offset; j < sym->param_count; j++) { + if (sym->param_names && sym->param_names[j] && + 
strcmp(arg_names[i], sym->param_names[j]) == 0) { + reordered[j - param_offset] = args[i]; + found = true; break; + } + } + if (!found) { + error_add(errors, "", node->loc.line, node->loc.col, + "'%s' 没有名为 '%s' 的参数", call_name, arg_names[i]); + return false; + } + } else { + reordered[i] = args[i]; + } + } + memcpy(args, reordered, arg_count * sizeof(AstNode*)); + return true; +} + +void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { + Symbol* sym = scope_lookup(scope, node->as.call.name); + if (!sym || sym->kind != SYM_FUNCTION) { + error_add(errors, "", node->loc.line, node->loc.col, + "未定义的函数 '%s'", node->as.call.name); + node->type.kind = TYPE_ERROR; + for (size_t i = 0; i < node->as.call.arg_count; i++) + analyze_expr(node->as.call.args[i], scope, errors, a); + return; + } + if (node->as.call.arg_count != sym->param_count) { + error_add(errors, "", node->loc.line, node->loc.col, + "函数 '%s' 需要 %zu 个参数,但提供了 %zu 个", + node->as.call.name, sym->param_count, node->as.call.arg_count); + node->type.kind = TYPE_ERROR; + for (size_t i = 0; i < node->as.call.arg_count; i++) + analyze_expr(node->as.call.args[i], scope, errors, a); + return; + } + if (!reorder_named_args(node, sym, 0, errors, node->as.call.name)) { + node->type.kind = TYPE_ERROR; return; + } + for (size_t i = 0; i < node->as.call.arg_count; i++) { + analyze_expr(node->as.call.args[i], scope, errors, a); + bool is_generic_param = check_arg_type(node->as.call.args[i], sym->param_types[i], + sym->param_struct_names ? sym->param_struct_names[i] : NULL, + i, node, sym, errors, a); + // 泛型单态化: 创建具象化函数副本并注册 + if (is_generic_param && sym->type_params && sym->type_param_count > 0) { + TypeKind concrete = node->as.call.args[i]->type.kind; + const char* concrete_sn = node->as.call.args[i]->type.struct_name; + // 构造 mangled 名: fn$concrete_type + const char* ct_name = concrete_sn ? 
concrete_sn : type_name(concrete); + int mname_len = snprintf(NULL, 0, "%s$%s", node->as.call.name, ct_name) + 1; + char* mname = arena_alloc_impl(a, mname_len); + snprintf(mname, mname_len, "%s$%s", node->as.call.name, ct_name); + // 检查是否已存在 + Symbol* existing = scope_lookup(scope, mname); + if (!existing && g_program) { + // 查找原始泛型函数 AST 节点 + AstNode* generic_fn = NULL; + for (size_t fn_i = 0; fn_i < g_program->as.program.fn_count; fn_i++) { + if (strcmp(g_program->as.program.functions[fn_i]->as.function.name, + node->as.call.name) == 0) { + generic_fn = g_program->as.program.functions[fn_i]; + break; + } + } + if (generic_fn && mono_count < 256) { + // 创建浅拷贝(共享 body,subst_ast_types 修改类型标注) + AstNode* mono_fn = ast_make_function(a, mname, + generic_fn->as.function.params, + generic_fn->as.function.param_count, + generic_fn->as.function.return_type, + generic_fn->as.function.return_struct_type_name, + generic_fn->as.function.body, + false, NULL, 0, + generic_fn->loc); + // 类型替换: T → concrete + subst_ast_types(mono_fn, sym->type_params[0], concrete, concrete_sn); + // 注册到队列 + mono_queue[mono_count++] = mono_fn; + // 注册符号(后续分析会处理函数体) + TypeKind* mpts = mono_fn->as.function.param_count > 0 + ? 
arena_alloc_impl(a, mono_fn->as.function.param_count * sizeof(TypeKind)) : NULL; + for (size_t pj = 0; pj < mono_fn->as.function.param_count; pj++) { + mpts[pj] = mono_fn->as.function.params[pj]->as.parameter.type; + } + scope_insert_function(scope, a, mname, + mono_fn->as.function.return_type, + mono_fn->as.function.return_struct_type_name, + mpts, NULL, NULL, + mono_fn->as.function.param_count, NULL, 0); + } + } + // 重定向调用到单态化函数 + node->as.call.name = mname; + sym = scope_lookup(scope, mname); + if (!sym) { node->type.kind = TYPE_ERROR; return; } + } + } + node->type.kind = sym->return_type; + if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) + node->type.struct_name = sym->return_struct_type_name; +} + +void analyze_field_access(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { + analyze_expr(node->as.field_access.object, scope, errors, a); + AstNode* obj = node->as.field_access.object; + if (obj->type.kind == TYPE_ERROR) { node->type.kind = TYPE_ERROR; return; } + if (obj->type.kind != TYPE_STRUCT) { + error_add(errors, "", node->loc.line, node->loc.col, + "类型 '%s' 不是结构体,不能访问字段 '%s'", + type_name(obj->type.kind), node->as.field_access.field); + node->type.kind = TYPE_ERROR; return; + } + const char* struct_name = obj->type.struct_name; + if (!struct_name) { + error_add(errors, "", node->loc.line, node->loc.col, "无法确定结构体类型"); + node->type.kind = TYPE_ERROR; return; + } + Symbol* struct_sym = scope_lookup_struct(scope, struct_name); + if (!struct_sym) { + error_add(errors, "", node->loc.line, node->loc.col, + "未定义的结构体 '%s'", struct_name); + node->type.kind = TYPE_ERROR; return; + } + int fi = scope_struct_field_index(struct_sym, node->as.field_access.field); + if (fi < 0) { + error_add(errors, "", node->loc.line, node->loc.col, + "结构体 '%s' 没有字段 '%s'", struct_name, node->as.field_access.field); + node->type.kind = TYPE_ERROR; return; + } + node->type.kind = struct_sym->struct_field_types[fi]; + node->as.field_access.field_index = fi; + 
if (node->type.kind == TYPE_STRUCT && struct_sym->struct_field_struct_names && + struct_sym->struct_field_struct_names[fi]) + node->type.struct_name = struct_sym->struct_field_struct_names[fi]; +} + +void analyze_struct_init(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { + const char* resolved = node->as.struct_init.type_name; + Symbol* struct_sym = scope_lookup_struct(scope, resolved); + if (!struct_sym) { + Symbol* alias_sym = scope_lookup(scope, resolved); + if (alias_sym && alias_sym->is_type_alias && alias_sym->struct_type_name) { + resolved = alias_sym->struct_type_name; + struct_sym = scope_lookup_struct(scope, resolved); + node->as.struct_init.type_name = resolved; + } + } + if (!struct_sym) { + error_add(errors, "", node->loc.line, node->loc.col, + "未定义的结构体类型 '%s'", node->as.struct_init.type_name); + node->type.kind = TYPE_ERROR; return; + } + if (node->as.struct_init.field_count != struct_sym->struct_field_count) { + error_add(errors, "", node->loc.line, node->loc.col, + "结构体 '%s' 有 %zu 个字段,但提供了 %zu 个", + node->as.struct_init.type_name, + struct_sym->struct_field_count, node->as.struct_init.field_count); + node->type.kind = TYPE_ERROR; return; + } + for (size_t i = 0; i < node->as.struct_init.field_count; i++) { + const char* fname = node->as.struct_init.field_names[i]; + AstNode* fval = node->as.struct_init.field_values[i]; + analyze_expr(fval, scope, errors, a); + int fi = scope_struct_field_index(struct_sym, fname); + if (fi < 0) { + error_add(errors, "", node->loc.line, node->loc.col, + "结构体 '%s' 没有字段 '%s'", node->as.struct_init.type_name, fname); + node->type.kind = TYPE_ERROR; continue; + } + TypeKind expected = struct_sym->struct_field_types[fi]; + TypeKind actual = fval->type.kind; + if (actual != TYPE_ERROR && actual != expected) + error_add(errors, "", node->loc.line, node->loc.col, + "字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'", + fname, type_name(expected), type_name(actual)); + } + if (node->type.kind != TYPE_ERROR) { + node->type.kind = 
TYPE_STRUCT; + node->type.struct_name = resolved; + } +} + +void analyze_enum_variant(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { + (void)a; + Symbol* enum_sym = scope_lookup_struct(scope, node->as.enum_variant.enum_name); + if (!enum_sym || enum_sym->kind != SYM_ENUM) { + error_add(errors, "", node->loc.line, node->loc.col, + "未定义的枚举 '%s'", node->as.enum_variant.enum_name); + node->type.kind = TYPE_ERROR; return; + } + int vi = scope_enum_variant_index(enum_sym, node->as.enum_variant.variant_name); + if (vi < 0) { + error_add(errors, "", node->loc.line, node->loc.col, + "枚举 '%s' 没有变体 '%s'", + node->as.enum_variant.enum_name, node->as.enum_variant.variant_name); + node->type.kind = TYPE_ERROR; return; + } + node->as.enum_variant.variant_index = vi; + // ADT: 检查 payload + TypeKind expected_pt = TYPE_VOID; + if (enum_sym->variant_payload_types) + expected_pt = enum_sym->variant_payload_types[vi]; + if (node->as.enum_variant.payload) { + if (expected_pt == TYPE_VOID && enum_sym->variant_payload_types) { + error_add(errors, "", node->loc.line, node->loc.col, + "枚举变体 '%s::%s' 不接受 payload", + node->as.enum_variant.enum_name, node->as.enum_variant.variant_name); + node->type.kind = TYPE_ERROR; return; + } + analyze_expr(node->as.enum_variant.payload, scope, errors, a); + TypeKind actual = node->as.enum_variant.payload->type.kind; + if (actual != TYPE_ERROR && actual != expected_pt) { + error_add(errors, "", node->loc.line, node->loc.col, + "枚举变体 payload 类型不匹配: 期望 '%s',得到 '%s'", + type_name(expected_pt), type_name(actual)); + } + } + node->type.kind = TYPE_ENUM; +} + +void analyze_index_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { + analyze_expr(node->as.index_expr.array, scope, errors, a); + analyze_expr(node->as.index_expr.index, scope, errors, a); + AstNode* arr = node->as.index_expr.array; + AstNode* idx = node->as.index_expr.index; + if (arr->type.kind == TYPE_ERROR) { node->type.kind = TYPE_ERROR; return; } + if (arr->type.kind != 
TYPE_ARRAY) { + error_add(errors, "", node->loc.line, node->loc.col, + "类型 '%s' 不支持索引操作", type_name(arr->type.kind)); + node->type.kind = TYPE_ERROR; return; + } + if (idx->type.kind == TYPE_ERROR) { node->type.kind = TYPE_ERROR; return; } + if (idx->type.kind != TYPE_I64) { + error_add(errors, "", node->loc.line, node->loc.col, + "数组索引必须是 i64 类型, 得到 '%s'", type_name(idx->type.kind)); + node->type.kind = TYPE_ERROR; return; + } + node->type.kind = arr->type.element_type; + node->type.struct_name = arr->type.element_struct_name; +} + +void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { + analyze_expr(node->as.method_call.receiver, scope, errors, a); + const char* recv_struct = node->as.method_call.receiver->type.struct_name; + if (node->as.method_call.receiver->type.kind != TYPE_STRUCT || !recv_struct) { + error_add(errors, "", node->loc.line, node->loc.col, + "只有结构体类型支持方法调用"); + node->type.kind = TYPE_ERROR; return; + } + char mangled[256]; + snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct, + node->as.method_call.method_name); + Symbol* sym = scope_lookup(scope, mangled); + // trait 方法 fallback: 搜索所有作用域中以 $method_name 结尾且以 StructName 开头的符号 + if (!sym || sym->kind != SYM_FUNCTION) { + char suffix[256]; + snprintf(suffix, sizeof(suffix), "$%s", node->as.method_call.method_name); + size_t suf_len = strlen(suffix); + size_t recv_len = strlen(recv_struct); + for (const Scope* sc = scope; sc; sc = sc->parent) { + for (Symbol* s = sc->head; s; s = s->next) { + if (s->kind == SYM_FUNCTION) { + size_t name_len = strlen(s->name); + if (name_len > suf_len + recv_len + && strncmp(s->name, recv_struct, recv_len) == 0 + && strcmp(s->name + name_len - suf_len, suffix) == 0) { + sym = s; + break; + } + } + } + if (sym) break; + } + } + // 更新 method_name 为符号的实际名称(codegen 需要通过它找到 LLVM 函数) + if (sym && sym->kind == SYM_FUNCTION) { + node->as.method_call.method_name = sym->name; + } + if (!sym || sym->kind != SYM_FUNCTION) { + error_add(errors, 
"", node->loc.line, node->loc.col, + "结构体 '%s' 没有方法 '%s'", recv_struct, + node->as.method_call.method_name); + node->type.kind = TYPE_ERROR; return; + } + if (node->as.method_call.arg_count + 1 != sym->param_count) { + error_add(errors, "", node->loc.line, node->loc.col, + "方法 '%s' 需要 %zu 个参数,提供了 %zu 个", + node->as.method_call.method_name, + sym->param_count > 0 ? sym->param_count - 1 : 0, + node->as.method_call.arg_count); + node->type.kind = TYPE_ERROR; return; + } + if (!reorder_named_args(node, sym, 1, errors, node->as.method_call.method_name)) { + node->type.kind = TYPE_ERROR; return; + } + for (size_t i = 0; i < node->as.method_call.arg_count; i++) { + analyze_expr(node->as.method_call.args[i], scope, errors, a); + check_arg_type(node->as.method_call.args[i], sym->param_types[i + 1], + sym->param_struct_names ? sym->param_struct_names[i + 1] : NULL, + i, node, sym, errors, a); + } + node->type.kind = sym->return_type; + if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) + node->type.struct_name = sym->return_struct_type_name; +} + +// === 表达式类型检查(调度器) === +void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { + switch (node->kind) { + case AST_LITERAL_EXPR: break; + case AST_IDENT_EXPR: analyze_ident_expr(node, scope, errors, a); break; + case AST_UNARY_EXPR: analyze_unary_expr(node, scope, errors, a); break; + case AST_BINARY_EXPR: analyze_binary_expr(node, scope, errors, a); break; + case AST_CALL_EXPR: analyze_call_expr(node, scope, errors, a); break; + case AST_FIELD_ACCESS: analyze_field_access(node, scope, errors, a); break; + case AST_STRUCT_INIT: analyze_struct_init(node, scope, errors, a); break; + case AST_ENUM_VARIANT: analyze_enum_variant(node, scope, errors, a); break; + case AST_INDEX_EXPR: analyze_index_expr(node, scope, errors, a); break; + case AST_METHOD_CALL: analyze_method_call(node, scope, errors, a); break; + case AST_IF_STMT: analyze_node(node, scope, errors, a); break; + case AST_BLOCK: 
analyze_node(node, scope, errors, a); break; + default: break; + } +}