refactor: sema.c + codegen.c 拆分,全部源文件 <800 行
sema.c 1129行 → sema.c 499行 + typeck.c 629行 + sema_internal.h 51行
- typeck.c: 表达式类型检查 (10个analyze_*函数) + 泛型单态化 + 类型关系
- sema.c: analyze_node + sema_analyze

codegen.c 947行 → codegen.c 453行 + cg_expr.c 440行 + codegen_internal.h 83行
- cg_expr.c: LLVM表达式生成 + 类型映射 (to_llvm_type/coerce_int/type_info_to_llvm)
- codegen.c: 语句生成 + 模块入口 + 符号表 + 内存清理

全部核心源文件 <800 行限制: parser(662+498), sema(499+629), codegen(453+440)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,440 @@
|
||||
#include "codegen_internal.h"
|
||||
|
||||
/*
 * Map a TypeKind to an LLVM type created in ctx->context.
 *
 *   TYPE_I32            -> i32
 *   TYPE_I64, TYPE_U64  -> i64
 *   TYPE_F64            -> double
 *   TYPE_BOOL           -> i1
 *   TYPE_CHAR           -> i8
 *   TYPE_STR            -> pointer to i8 (address space 0)
 *   TYPE_STRUCT/ENUM    -> anonymous, non-packed { i64, i64 }
 *                          (tagged-union layout: tag, payload)
 *   anything else       -> void (includes TYPE_UNKNOWN and TYPE_ERROR)
 *
 * Only the kind is consulted, so a named struct cannot be resolved here.
 */
LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) {
    LLVMContextRef llctx = ctx->context;

    if (kind == TYPE_I32)
        return LLVMInt32TypeInContext(llctx);
    if (kind == TYPE_I64 || kind == TYPE_U64)
        return LLVMInt64TypeInContext(llctx);
    if (kind == TYPE_F64)
        return LLVMDoubleTypeInContext(llctx);
    if (kind == TYPE_BOOL)
        return LLVMInt1TypeInContext(llctx);
    if (kind == TYPE_CHAR)
        return LLVMInt8TypeInContext(llctx);
    if (kind == TYPE_STR)
        return LLVMPointerType(LLVMInt8TypeInContext(llctx), 0);

    if (kind == TYPE_STRUCT || kind == TYPE_ENUM) {
        /* tagged union: { i64 tag, i64 payload } */
        LLVMTypeRef word = LLVMInt64TypeInContext(llctx);
        LLVMTypeRef pair[2] = { word, word };
        return LLVMStructTypeInContext(llctx, pair, 2, false);
    }

    /* TYPE_UNKNOWN, TYPE_ERROR and every kind not listed above. */
    return LLVMVoidTypeInContext(llctx);
}
|
||||
|
||||
/*
 * Build an LLVM constant of type `ty` from a literal AST node.
 *
 * The literal's own lit_type selects which payload field is read:
 *   - TYPE_I32 / TYPE_I64: i64_val, sign-extended into `ty`
 *   - TYPE_U64 / TYPE_CHAR: i64_val, zero-extended into `ty`
 *   - TYPE_F64: f64_val
 *   - TYPE_BOOL: 1 or 0 depending on bool_val
 *
 * Returns NULL for any other literal type.
 */
LLVMValueRef to_llvm_const(LLVMTypeRef ty, AstNode* lit) {
    TypeKind lk = lit->as.literal.lit_type;

    if (lk == TYPE_F64)
        return LLVMConstReal(ty, lit->as.literal.f64_val);

    if (lk == TYPE_BOOL)
        return LLVMConstInt(ty, lit->as.literal.bool_val ? 1 : 0, false);

    if (lk == TYPE_I32 || lk == TYPE_I64 || lk == TYPE_U64 || lk == TYPE_CHAR) {
        /* Only the signed kinds request sign extension. */
        bool sign_extend = (lk == TYPE_I32 || lk == TYPE_I64);
        return LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, sign_extend);
    }

    return NULL;
}
|
||||
|
||||
/*
 * Convert an integer value to another integer type.
 *
 *   - identical types, or equal bit widths: `val` is returned untouched
 *   - narrower -> wider: sign extension (sext)
 *   - wider -> narrower: truncation (trunc)
 *
 * If either type is not an integer type, `val` is returned unchanged instead
 * of querying its width: LLVMGetIntTypeWidth is only meaningful for integer
 * types, and callers pass LLVMTypeOf(value) without checking it first.
 *
 * NOTE(review): widening is always signed, so an i1 `true` becomes -1 and an
 * unsigned value with its top bit set becomes negative; no signedness
 * information reaches this function.
 */
LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val,
                        LLVMTypeRef from_ty, LLVMTypeRef to_ty) {
    if (from_ty == to_ty)
        return val;

    if (LLVMGetTypeKind(from_ty) != LLVMIntegerTypeKind ||
        LLVMGetTypeKind(to_ty) != LLVMIntegerTypeKind)
        return val;

    unsigned from_w = LLVMGetIntTypeWidth(from_ty);
    unsigned to_w = LLVMGetIntTypeWidth(to_ty);

    /* A trunc between equal widths is invalid IR; nothing to convert. */
    if (from_w == to_w)
        return val;

    if (from_w < to_w)
        return LLVMBuildSExt(ctx->builder, val, to_ty, "sext");

    return LLVMBuildTrunc(ctx->builder, val, to_ty, "trunc");
}
|
||||
|
||||
// 从 TypeInfo 生成 LLVM 类型(支持数组、结构体等复合类型)
|
||||
LLVMTypeRef type_info_to_llvm(CgCtx* ctx, const TypeInfo* ti) {
|
||||
switch (ti->kind) {
|
||||
case TYPE_ARRAY: {
|
||||
TypeInfo elem = { .kind = ti->element_type, .struct_name = ti->element_struct_name };
|
||||
LLVMTypeRef elem_ty = type_info_to_llvm(ctx, &elem);
|
||||
return LLVMArrayType(elem_ty, (unsigned)ti->array_size);
|
||||
}
|
||||
case TYPE_STRUCT:
|
||||
if (ti->struct_name) {
|
||||
LLVMTypeRef st = find_struct_type(ctx, ti->struct_name);
|
||||
if (st) return st;
|
||||
}
|
||||
return LLVMVoidTypeInContext(ctx->context);
|
||||
case TYPE_ENUM: {
|
||||
LLVMTypeRef f[] = { LLVMInt64TypeInContext(ctx->context),
|
||||
LLVMInt64TypeInContext(ctx->context) };
|
||||
return LLVMStructTypeInContext(ctx->context, f, 2, false);
|
||||
}
|
||||
default:
|
||||
return to_llvm_type(ctx, ti->kind);
|
||||
}
|
||||
}
|
||||
|
||||
// === 向前声明 ===
|
||||
LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node);
|
||||
void codegen_stmt(CgCtx* ctx, AstNode* node);
|
||||
|
||||
LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) {
|
||||
if (!node) return NULL;
|
||||
|
||||
switch (node->kind) {
|
||||
case AST_LITERAL_EXPR:
|
||||
if (node->type.kind == TYPE_STR) {
|
||||
return LLVMBuildGlobalStringPtr(ctx->builder, node->as.literal.str_val, "str");
|
||||
}
|
||||
return to_llvm_const(to_llvm_type(ctx, node->type.kind), node);
|
||||
|
||||
case AST_IDENT_EXPR: {
|
||||
LLVMValueRef ptr = find_var(ctx, node->as.ident.name);
|
||||
if (!ptr) return NULL;
|
||||
LLVMTypeRef load_ty = type_info_to_llvm(ctx, &node->type);
|
||||
return LLVMBuildLoad2(ctx->builder, load_ty, ptr, "load");
|
||||
}
|
||||
|
||||
case AST_UNARY_EXPR: {
|
||||
LLVMValueRef operand = codegen_expr(ctx, node->as.unary.operand);
|
||||
if (!operand) return NULL;
|
||||
if (node->as.unary.op == OP_NEG) {
|
||||
if (node->type.kind == TYPE_F64)
|
||||
return LLVMBuildFNeg(ctx->builder, operand, "fneg");
|
||||
else
|
||||
return LLVMBuildNeg(ctx->builder, operand, "ineg");
|
||||
} else {
|
||||
return LLVMBuildNot(ctx->builder, operand, "not");
|
||||
}
|
||||
}
|
||||
|
||||
case AST_BINARY_EXPR: {
|
||||
LLVMValueRef l = codegen_expr(ctx, node->as.binary.left);
|
||||
LLVMValueRef r = codegen_expr(ctx, node->as.binary.right);
|
||||
if (!l || !r) return NULL;
|
||||
|
||||
// 字符串拼接:alloc 栈缓冲区,strcpy + strcat
|
||||
if (node->type.kind == TYPE_STR) {
|
||||
// strlen(left)
|
||||
LLVMValueRef len_l = LLVMBuildCall2(ctx->builder,
|
||||
LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn,
|
||||
(LLVMValueRef[]){l}, 1, "strlen_l");
|
||||
// strlen(right)
|
||||
LLVMValueRef len_r = LLVMBuildCall2(ctx->builder,
|
||||
LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn,
|
||||
(LLVMValueRef[]){r}, 1, "strlen_r");
|
||||
// total = len_l + len_r + 1
|
||||
LLVMValueRef total = LLVMBuildAdd(ctx->builder, len_l, len_r, "total");
|
||||
total = LLVMBuildAdd(ctx->builder, total,
|
||||
LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "total_1");
|
||||
// char* buf = malloc(total)
|
||||
LLVMValueRef buf = LLVMBuildCall2(ctx->builder,
|
||||
LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn,
|
||||
(LLVMValueRef[]){total}, 1, "str_buf");
|
||||
// memcpy(buf, left, len_l)
|
||||
LLVMBuildCall2(ctx->builder,
|
||||
LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn,
|
||||
(LLVMValueRef[]){buf, l, len_l}, 3, "");
|
||||
// memcpy(buf + len_l, right, len_r + 1) -- includes null terminator
|
||||
LLVMValueRef offset_ptr = LLVMBuildGEP2(ctx->builder,
|
||||
LLVMInt8TypeInContext(ctx->context), buf,
|
||||
(LLVMValueRef[]){len_l}, 1, "offset");
|
||||
LLVMValueRef len_r1 = LLVMBuildAdd(ctx->builder, len_r,
|
||||
LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "len_r1");
|
||||
LLVMBuildCall2(ctx->builder,
|
||||
LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn,
|
||||
(LLVMValueRef[]){offset_ptr, r, len_r1}, 3, "");
|
||||
return buf;
|
||||
}
|
||||
|
||||
bool is_float = (node->type.kind == TYPE_F64);
|
||||
|
||||
switch (node->as.binary.op) {
|
||||
case OP_ADD:
|
||||
return is_float ? LLVMBuildFAdd(ctx->builder, l, r, "fadd")
|
||||
: LLVMBuildAdd(ctx->builder, l, r, "iadd");
|
||||
case OP_SUB:
|
||||
return is_float ? LLVMBuildFSub(ctx->builder, l, r, "fsub")
|
||||
: LLVMBuildSub(ctx->builder, l, r, "isub");
|
||||
case OP_MUL:
|
||||
return is_float ? LLVMBuildFMul(ctx->builder, l, r, "fmul")
|
||||
: LLVMBuildMul(ctx->builder, l, r, "imul");
|
||||
case OP_DIV:
|
||||
return is_float ? LLVMBuildFDiv(ctx->builder, l, r, "fdiv")
|
||||
: LLVMBuildSDiv(ctx->builder, l, r, "sdiv");
|
||||
case OP_MOD:
|
||||
return LLVMBuildSRem(ctx->builder, l, r, "srem");
|
||||
case OP_EQ:
|
||||
case OP_NE:
|
||||
case OP_LT:
|
||||
case OP_GT:
|
||||
case OP_LE:
|
||||
case OP_GE: {
|
||||
// 枚举比较: 提取 tag 字段再比较
|
||||
LLVMValueRef cl = l, cr = r;
|
||||
if (node->as.binary.left->type.kind == TYPE_ENUM) {
|
||||
cl = LLVMBuildExtractValue(ctx->builder, l, 0, "enum_tag_l");
|
||||
}
|
||||
if (node->as.binary.right->type.kind == TYPE_ENUM) {
|
||||
cr = LLVMBuildExtractValue(ctx->builder, r, 0, "enum_tag_r");
|
||||
}
|
||||
LLVMIntPredicate pred;
|
||||
switch (node->as.binary.op) {
|
||||
case OP_EQ: pred = LLVMIntEQ; break;
|
||||
case OP_NE: pred = LLVMIntNE; break;
|
||||
case OP_LT: pred = LLVMIntSLT; break;
|
||||
case OP_GT: pred = LLVMIntSGT; break;
|
||||
case OP_LE: pred = LLVMIntSLE; break;
|
||||
case OP_GE: pred = LLVMIntSGE; break;
|
||||
default: return NULL;
|
||||
}
|
||||
if (is_float)
|
||||
return LLVMBuildFCmp(ctx->builder, pred == LLVMIntEQ ? LLVMRealOEQ :
|
||||
pred == LLVMIntNE ? LLVMRealONE : pred == LLVMIntSLT ? LLVMRealOLT :
|
||||
pred == LLVMIntSGT ? LLVMRealOGT : pred == LLVMIntSLE ? LLVMRealOLE :
|
||||
LLVMRealOGE, cl, cr, "fcmp");
|
||||
return LLVMBuildICmp(ctx->builder, pred, cl, cr, "icmp");
|
||||
}
|
||||
case OP_AND:
|
||||
return LLVMBuildAnd(ctx->builder, l, r, "and");
|
||||
case OP_OR:
|
||||
return LLVMBuildOr(ctx->builder, l, r, "or");
|
||||
default:
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
case AST_CALL_EXPR: {
|
||||
// === 内置 print 函数:委托给 printf ===
|
||||
if (strcmp(node->as.call.name, "print_i64") == 0) {
|
||||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||||
if (!arg) return NULL;
|
||||
// 枚举类型: 提取 tag 字段
|
||||
if (node->as.call.args[0]->type.kind == TYPE_ENUM)
|
||||
arg = LLVMBuildExtractValue(ctx->builder, arg, 0, "tag");
|
||||
LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context);
|
||||
arg = coerce_int(ctx, arg, LLVMTypeOf(arg), i64_ty);
|
||||
LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%lld\n", "fmt_i64");
|
||||
LLVMValueRef printf_args[] = { fmt, arg };
|
||||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||||
printf_args, 2, "");
|
||||
}
|
||||
if (strcmp(node->as.call.name, "print_f64") == 0) {
|
||||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||||
if (!arg) return NULL;
|
||||
LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%f\n", "fmt_f64");
|
||||
LLVMValueRef printf_args[] = { fmt, arg };
|
||||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||||
printf_args, 2, "");
|
||||
}
|
||||
if (strcmp(node->as.call.name, "print_bool") == 0) {
|
||||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||||
if (!arg) return NULL;
|
||||
// 将 bool 转为字符串:通过 select 在 "true\n" 和 "false\n" 之间选择
|
||||
LLVMValueRef c = LLVMBuildICmp(ctx->builder, LLVMIntEQ, arg,
|
||||
LLVMConstInt(LLVMInt1TypeInContext(ctx->context), 1, false), "bool_cmp");
|
||||
LLVMValueRef true_str = LLVMBuildGlobalStringPtr(ctx->builder, "true\n", "true_str");
|
||||
LLVMValueRef false_str = LLVMBuildGlobalStringPtr(ctx->builder, "false\n", "false_str");
|
||||
LLVMValueRef selected = LLVMBuildSelect(ctx->builder, c, true_str, false_str, "bool_sel");
|
||||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||||
(LLVMValueRef[]){selected}, 1, "");
|
||||
}
|
||||
if (strcmp(node->as.call.name, "print_str") == 0) {
|
||||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||||
if (!arg) return NULL;
|
||||
LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%s\n", "fmt_str");
|
||||
LLVMValueRef printf_args[] = { fmt, arg };
|
||||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||||
printf_args, 2, "");
|
||||
}
|
||||
|
||||
// === 常规函数调用 ===
|
||||
LLVMValueRef fn = find_fn(ctx, node->as.call.name);
|
||||
if (!fn) return NULL;
|
||||
LLVMValueRef args[16];
|
||||
if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; }
|
||||
for (size_t i = 0; i < node->as.call.arg_count; i++) {
|
||||
args[i] = codegen_expr(ctx, node->as.call.args[i]);
|
||||
if (!args[i]) return NULL;
|
||||
}
|
||||
LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
|
||||
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
|
||||
return LLVMBuildCall2(ctx->builder, fn_ty, fn,
|
||||
args, (unsigned)node->as.call.arg_count,
|
||||
ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call");
|
||||
}
|
||||
|
||||
// === 结构体字段访问: p.x ===
|
||||
case AST_FIELD_ACCESS: {
|
||||
// 对对象求值(返回的是 struct 值)
|
||||
LLVMValueRef struct_val = codegen_expr(ctx, node->as.field_access.object);
|
||||
if (!struct_val) return NULL;
|
||||
|
||||
int field_idx = node->as.field_access.field_index;
|
||||
if (field_idx < 0) return NULL; // sema 应当已经设置
|
||||
|
||||
// 用 extractvalue 从结构体值中提取字段
|
||||
return LLVMBuildExtractValue(ctx->builder, struct_val,
|
||||
(unsigned)field_idx, node->as.field_access.field);
|
||||
}
|
||||
|
||||
// === 结构体初始化: Point { x: 10, y: 20 } ===
|
||||
case AST_STRUCT_INIT: {
|
||||
const char* st_name = node->as.struct_init.type_name;
|
||||
LLVMTypeRef struct_ty = find_struct_type(ctx, st_name);
|
||||
if (!struct_ty) return NULL;
|
||||
|
||||
// alloca 分配结构体空间
|
||||
LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, struct_ty, "struct_init");
|
||||
|
||||
// 获取结构体字段名列表(从 struct_table 或从 AST 中)
|
||||
// 对每个 init 字段,找到它在结构体中的索引并 store
|
||||
for (size_t i = 0; i < node->as.struct_init.field_count; i++) {
|
||||
AstNode* fval = node->as.struct_init.field_values[i];
|
||||
LLVMValueRef val = codegen_expr(ctx, fval);
|
||||
if (!val) return NULL;
|
||||
|
||||
// 获取字段指针: GEP struct_ty, alloca, 0, i
|
||||
LLVMValueRef indices[] = {
|
||||
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
|
||||
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), (unsigned long long)i, false)
|
||||
};
|
||||
LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, struct_ty, alloca,
|
||||
indices, 2, "field_ptr");
|
||||
LLVMBuildStore(ctx->builder, val, field_ptr);
|
||||
}
|
||||
|
||||
// 加载整个结构体值
|
||||
return LLVMBuildLoad2(ctx->builder, struct_ty, alloca, "struct_val");
|
||||
}
|
||||
|
||||
case AST_ENUM_VARIANT: {
|
||||
// tagged union: { tag, payload }
|
||||
LLVMValueRef tag = LLVMConstInt(LLVMInt64TypeInContext(ctx->context),
|
||||
(unsigned long long)node->as.enum_variant.variant_index, true);
|
||||
LLVMValueRef payload = LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 0, true);
|
||||
if (node->as.enum_variant.payload) {
|
||||
LLVMValueRef pv = codegen_expr(ctx, node->as.enum_variant.payload);
|
||||
if (pv) {
|
||||
// 将 payload 强制转换为 i64
|
||||
LLVMTypeRef pv_ty = LLVMTypeOf(pv);
|
||||
LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context);
|
||||
if (pv_ty != i64_ty && LLVMGetTypeKind(pv_ty) == LLVMIntegerTypeKind)
|
||||
pv = coerce_int(ctx, pv, pv_ty, i64_ty);
|
||||
payload = pv;
|
||||
}
|
||||
}
|
||||
LLVMValueRef fields[] = { tag, payload };
|
||||
return LLVMConstStruct(fields, 2, false);
|
||||
}
|
||||
|
||||
case AST_METHOD_CALL: {
|
||||
const char* struct_name = node->as.method_call.receiver->type.struct_name;
|
||||
char mangled[256];
|
||||
// 若 method_name 已含 $(trait 方法,sema 已设置全限定名),直接用
|
||||
if (strchr(node->as.method_call.method_name, '$'))
|
||||
snprintf(mangled, sizeof(mangled), "%s", node->as.method_call.method_name);
|
||||
else
|
||||
snprintf(mangled, sizeof(mangled), "%s$%s", struct_name,
|
||||
node->as.method_call.method_name);
|
||||
LLVMValueRef fn = find_fn(ctx, mangled);
|
||||
if (!fn) return NULL;
|
||||
// 参数列表: [receiver, 用户参数...]
|
||||
if (node->as.method_call.arg_count + 1 > 16) { ctx->error = "方法参数过多(最多15)"; return NULL; }
|
||||
LLVMValueRef args[16];
|
||||
args[0] = codegen_expr(ctx, node->as.method_call.receiver);
|
||||
if (!args[0]) return NULL;
|
||||
for (size_t i = 0; i < node->as.method_call.arg_count; i++) {
|
||||
args[i + 1] = codegen_expr(ctx, node->as.method_call.args[i]);
|
||||
if (!args[i + 1]) return NULL;
|
||||
}
|
||||
LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
|
||||
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
|
||||
return LLVMBuildCall2(ctx->builder, fn_ty, fn, args,
|
||||
(unsigned)(node->as.method_call.arg_count + 1),
|
||||
ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "method_call");
|
||||
}
|
||||
|
||||
case AST_INDEX_EXPR: {
|
||||
// 获取数组变量的指针
|
||||
AstNode* arr_node = node->as.index_expr.array;
|
||||
LLVMValueRef arr_ptr = NULL;
|
||||
LLVMTypeRef arr_gp_type = NULL;
|
||||
|
||||
if (arr_node->kind == AST_IDENT_EXPR) {
|
||||
arr_ptr = find_var(ctx, arr_node->as.ident.name);
|
||||
// 从变量表获取数组类型用于 GEP
|
||||
for (VarEntry* e = ctx->var_table; e; e = e->next) {
|
||||
if (strcmp(e->name, arr_node->as.ident.name) == 0) {
|
||||
arr_gp_type = e->alloca_type; break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!arr_ptr || !arr_gp_type) return NULL;
|
||||
|
||||
// 生成索引值
|
||||
LLVMValueRef idx_val = codegen_expr(ctx, node->as.index_expr.index);
|
||||
if (!idx_val) return NULL;
|
||||
|
||||
// GEP 索引必须是 i32,但 L 使用 i64。截断。
|
||||
LLVMValueRef idx_i32 = LLVMBuildTrunc(ctx->builder, idx_val,
|
||||
LLVMInt32TypeInContext(ctx->context), "idx32");
|
||||
|
||||
LLVMValueRef indices[] = {
|
||||
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
|
||||
idx_i32
|
||||
};
|
||||
LLVMValueRef elem_ptr = LLVMBuildGEP2(ctx->builder, arr_gp_type, arr_ptr, indices, 2, "arr_elem");
|
||||
|
||||
LLVMTypeRef elem_load_ty;
|
||||
if (node->type.kind == TYPE_STRUCT && node->type.struct_name) {
|
||||
elem_load_ty = find_struct_type(ctx, node->type.struct_name);
|
||||
if (!elem_load_ty) elem_load_ty = to_llvm_type(ctx, node->type.kind);
|
||||
} else {
|
||||
elem_load_ty = type_info_to_llvm(ctx, &node->type);
|
||||
}
|
||||
return LLVMBuildLoad2(ctx->builder, elem_load_ty, elem_ptr, "arr_load");
|
||||
}
|
||||
|
||||
// 块表达式: { stmt*; expr } → 最后表达式的值
|
||||
case AST_BLOCK: {
|
||||
LLVMValueRef result = NULL;
|
||||
for (size_t i = 0; i < node->as.block.stmt_count; i++) {
|
||||
AstNode* stmt = node->as.block.stmts[i];
|
||||
bool is_last = (i == node->as.block.stmt_count - 1);
|
||||
if (is_last && stmt->kind == AST_EXPR_STMT && node->type.kind != TYPE_VOID) {
|
||||
result = codegen_expr(ctx, stmt->as.expr_stmt.expr);
|
||||
} else {
|
||||
codegen_stmt(ctx, stmt);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// if 表达式: if cond { a } else { b }
|
||||
case AST_IF_STMT: {
|
||||
if (node->type.kind == TYPE_VOID) { codegen_stmt(ctx, node); return NULL; }
|
||||
LLVMValueRef cond_val = codegen_expr(ctx, node->as.if_stmt.cond);
|
||||
if (!cond_val) return NULL;
|
||||
LLVMTypeRef res_ty = type_info_to_llvm(ctx, &node->type);
|
||||
LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, res_ty, "if_res");
|
||||
LLVMValueRef func = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx->builder));
|
||||
LLVMBasicBlockRef then_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "then");
|
||||
LLVMBasicBlockRef else_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "else");
|
||||
LLVMBasicBlockRef merge_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "if_merge");
|
||||
LLVMBuildCondBr(ctx->builder, cond_val, then_bb, else_bb);
|
||||
|
||||
LLVMPositionBuilderAtEnd(ctx->builder, then_bb);
|
||||
LLVMValueRef then_val = codegen_expr(ctx, node->as.if_stmt.then_block);
|
||||
if (then_val) LLVMBuildStore(ctx->builder, then_val, alloca);
|
||||
LLVMBuildBr(ctx->builder, merge_bb);
|
||||
|
||||
LLVMPositionBuilderAtEnd(ctx->builder, else_bb);
|
||||
LLVMValueRef else_val = codegen_expr(ctx, node->as.if_stmt.else_block);
|
||||
if (else_val) LLVMBuildStore(ctx->builder, else_val, alloca);
|
||||
LLVMBuildBr(ctx->builder, merge_bb);
|
||||
|
||||
LLVMPositionBuilderAtEnd(ctx->builder, merge_bb);
|
||||
return LLVMBuildLoad2(ctx->builder, res_ty, alloca, "if_val");
|
||||
}
|
||||
|
||||
default:
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
+11
-505
@@ -1,104 +1,15 @@
|
||||
#include "codegen.h"
|
||||
#include <llvm-c/Analysis.h>
|
||||
#include <llvm-c/Types.h>
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
#include "codegen_internal.h"
|
||||
|
||||
// === 递归深度限制
|
||||
static int codegen_depth = 0;
|
||||
#define MAX_CODEGEN_DEPTH 1000
|
||||
|
||||
// === 内部状态 ===
|
||||
typedef struct VarEntry {
|
||||
const char* name;
|
||||
LLVMValueRef alloca;
|
||||
LLVMTypeRef alloca_type; // 分配的类型(GEP 需要)
|
||||
struct VarEntry* next;
|
||||
} VarEntry;
|
||||
|
||||
typedef struct FnEntry {
|
||||
const char* name;
|
||||
LLVMValueRef fn;
|
||||
TypeKind ret;
|
||||
TypeKind* params;
|
||||
size_t pc;
|
||||
struct FnEntry* next;
|
||||
} FnEntry;
|
||||
|
||||
// 结构体类型映射
|
||||
typedef struct StructTypeEntry {
|
||||
const char* name;
|
||||
LLVMTypeRef llvm_type;
|
||||
size_t field_count;
|
||||
struct StructTypeEntry* next;
|
||||
} StructTypeEntry;
|
||||
|
||||
typedef struct {
|
||||
Arena* arena; // 代码生成阶段分配器
|
||||
LLVMContextRef context; // LLVM 19+ 需要显式 Context
|
||||
LLVMModuleRef module;
|
||||
LLVMBuilderRef builder;
|
||||
VarEntry* var_table;
|
||||
const char* error;
|
||||
FnEntry* fn_table;
|
||||
StructTypeEntry* struct_table;
|
||||
// printf 运行时支持(内置 print 函数委托给 printf)
|
||||
LLVMValueRef printf_fn;
|
||||
LLVMTypeRef printf_ty;
|
||||
// 字符串拼接运行时支持
|
||||
LLVMValueRef malloc_fn;
|
||||
LLVMValueRef free_fn; // auto-free 需要 free()
|
||||
LLVMValueRef strlen_fn;
|
||||
LLVMValueRef memcpy_fn;
|
||||
// 自动内存管理: 追踪需要 free 的 str alloca
|
||||
LLVMValueRef* cleanup_list;
|
||||
size_t cleanup_count;
|
||||
size_t cleanup_cap;
|
||||
} CgCtx;
|
||||
|
||||
// === 类型映射(需要 Context)===
|
||||
static LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) {
|
||||
switch (kind) {
|
||||
case TYPE_I32: return LLVMInt32TypeInContext(ctx->context);
|
||||
case TYPE_I64: return LLVMInt64TypeInContext(ctx->context);
|
||||
case TYPE_U64: return LLVMInt64TypeInContext(ctx->context);
|
||||
case TYPE_F64: return LLVMDoubleTypeInContext(ctx->context);
|
||||
case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context);
|
||||
case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context);
|
||||
case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
|
||||
case TYPE_STRUCT:
|
||||
case TYPE_ENUM: {
|
||||
// tagged union: { i64 tag, i64 payload }
|
||||
LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context),
|
||||
LLVMInt64TypeInContext(ctx->context) };
|
||||
return LLVMStructTypeInContext(ctx->context, fields, 2, false);
|
||||
}
|
||||
case TYPE_UNKNOWN:
|
||||
case TYPE_ERROR:
|
||||
default: return LLVMVoidTypeInContext(ctx->context);
|
||||
}
|
||||
}
|
||||
|
||||
static LLVMValueRef to_llvm_const(LLVMTypeRef ty, AstNode* lit) {
|
||||
switch (lit->as.literal.lit_type) {
|
||||
case TYPE_I32:
|
||||
case TYPE_I64: return LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, true);
|
||||
case TYPE_U64: return LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, false);
|
||||
case TYPE_CHAR: return LLVMConstInt(ty, (unsigned long long)lit->as.literal.i64_val, false);
|
||||
case TYPE_F64: return LLVMConstReal(ty, lit->as.literal.f64_val);
|
||||
case TYPE_BOOL: return LLVMConstInt(ty, lit->as.literal.bool_val ? 1 : 0, false);
|
||||
default: return NULL;
|
||||
}
|
||||
}
|
||||
int codegen_depth = 0;
|
||||
|
||||
// === 变量表 ===
|
||||
static LLVMValueRef find_var(CgCtx* ctx, const char* name) {
|
||||
LLVMValueRef find_var(CgCtx* ctx, const char* name) {
|
||||
for (VarEntry* e = ctx->var_table; e; e = e->next)
|
||||
if (strcmp(e->name, name) == 0) return e->alloca;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
|
||||
void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
|
||||
VarEntry* e = arena_alloc(ctx->arena, sizeof(*e));
|
||||
if (!e) return;
|
||||
e->name = name; e->alloca = alloca; e->alloca_type = alloca_type; e->next = ctx->var_table;
|
||||
@@ -106,13 +17,13 @@ static void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeR
|
||||
}
|
||||
|
||||
// === 函数表 ===
|
||||
static LLVMValueRef find_fn(CgCtx* ctx, const char* name) {
|
||||
LLVMValueRef find_fn(CgCtx* ctx, const char* name) {
|
||||
for (FnEntry* e = ctx->fn_table; e; e = e->next)
|
||||
if (strcmp(e->name, name) == 0) return e->fn;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn) {
|
||||
void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn) {
|
||||
FnEntry* e = arena_alloc(ctx->arena, sizeof(*e));
|
||||
if (!e) return;
|
||||
e->name = name; e->fn = fn;
|
||||
@@ -124,7 +35,7 @@ static void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn) {
|
||||
}
|
||||
|
||||
// === 结构体类型表 ===
|
||||
static void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc) {
|
||||
void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc) {
|
||||
StructTypeEntry* e = arena_alloc(ctx->arena, sizeof(*e));
|
||||
if (!e) return;
|
||||
e->name = name; e->llvm_type = ty; e->field_count = fc;
|
||||
@@ -132,420 +43,15 @@ static void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t
|
||||
ctx->struct_table = e;
|
||||
}
|
||||
|
||||
static LLVMTypeRef find_struct_type(CgCtx* ctx, const char* name) {
|
||||
LLVMTypeRef find_struct_type(CgCtx* ctx, const char* name) {
|
||||
for (StructTypeEntry* e = ctx->struct_table; e; e = e->next)
|
||||
if (strcmp(e->name, name) == 0) return e->llvm_type;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// 将整数值强制转换到目标 LLVM 类型(sext/zext/trunc)
|
||||
static LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val,
|
||||
LLVMTypeRef from_ty, LLVMTypeRef to_ty) {
|
||||
if (from_ty == to_ty) return val;
|
||||
int from_w = LLVMGetIntTypeWidth(from_ty);
|
||||
int to_w = LLVMGetIntTypeWidth(to_ty);
|
||||
if (from_w < to_w)
|
||||
return LLVMBuildSExt(ctx->builder, val, to_ty, "sext");
|
||||
else
|
||||
return LLVMBuildTrunc(ctx->builder, val, to_ty, "trunc");
|
||||
}
|
||||
|
||||
// 从 TypeInfo 生成 LLVM 类型(支持数组、结构体等复合类型)
|
||||
static LLVMTypeRef type_info_to_llvm(CgCtx* ctx, const TypeInfo* ti) {
|
||||
switch (ti->kind) {
|
||||
case TYPE_ARRAY: {
|
||||
TypeInfo elem = { .kind = ti->element_type, .struct_name = ti->element_struct_name };
|
||||
LLVMTypeRef elem_ty = type_info_to_llvm(ctx, &elem);
|
||||
return LLVMArrayType(elem_ty, (unsigned)ti->array_size);
|
||||
}
|
||||
case TYPE_STRUCT:
|
||||
if (ti->struct_name) {
|
||||
LLVMTypeRef st = find_struct_type(ctx, ti->struct_name);
|
||||
if (st) return st;
|
||||
}
|
||||
return LLVMVoidTypeInContext(ctx->context);
|
||||
case TYPE_ENUM: {
|
||||
LLVMTypeRef f[] = { LLVMInt64TypeInContext(ctx->context),
|
||||
LLVMInt64TypeInContext(ctx->context) };
|
||||
return LLVMStructTypeInContext(ctx->context, f, 2, false);
|
||||
}
|
||||
default:
|
||||
return to_llvm_type(ctx, ti->kind);
|
||||
}
|
||||
}
|
||||
|
||||
// === 向前声明 ===
|
||||
static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node);
|
||||
static void codegen_stmt(CgCtx* ctx, AstNode* node);
|
||||
|
||||
// === 表达式代码生成 ===
|
||||
static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) {
|
||||
if (!node) return NULL;
|
||||
|
||||
switch (node->kind) {
|
||||
case AST_LITERAL_EXPR:
|
||||
if (node->type.kind == TYPE_STR) {
|
||||
return LLVMBuildGlobalStringPtr(ctx->builder, node->as.literal.str_val, "str");
|
||||
}
|
||||
return to_llvm_const(to_llvm_type(ctx, node->type.kind), node);
|
||||
|
||||
case AST_IDENT_EXPR: {
|
||||
LLVMValueRef ptr = find_var(ctx, node->as.ident.name);
|
||||
if (!ptr) return NULL;
|
||||
LLVMTypeRef load_ty = type_info_to_llvm(ctx, &node->type);
|
||||
return LLVMBuildLoad2(ctx->builder, load_ty, ptr, "load");
|
||||
}
|
||||
|
||||
case AST_UNARY_EXPR: {
|
||||
LLVMValueRef operand = codegen_expr(ctx, node->as.unary.operand);
|
||||
if (!operand) return NULL;
|
||||
if (node->as.unary.op == OP_NEG) {
|
||||
if (node->type.kind == TYPE_F64)
|
||||
return LLVMBuildFNeg(ctx->builder, operand, "fneg");
|
||||
else
|
||||
return LLVMBuildNeg(ctx->builder, operand, "ineg");
|
||||
} else {
|
||||
return LLVMBuildNot(ctx->builder, operand, "not");
|
||||
}
|
||||
}
|
||||
|
||||
case AST_BINARY_EXPR: {
|
||||
LLVMValueRef l = codegen_expr(ctx, node->as.binary.left);
|
||||
LLVMValueRef r = codegen_expr(ctx, node->as.binary.right);
|
||||
if (!l || !r) return NULL;
|
||||
|
||||
// 字符串拼接:alloc 栈缓冲区,strcpy + strcat
|
||||
if (node->type.kind == TYPE_STR) {
|
||||
// strlen(left)
|
||||
LLVMValueRef len_l = LLVMBuildCall2(ctx->builder,
|
||||
LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn,
|
||||
(LLVMValueRef[]){l}, 1, "strlen_l");
|
||||
// strlen(right)
|
||||
LLVMValueRef len_r = LLVMBuildCall2(ctx->builder,
|
||||
LLVMGlobalGetValueType(ctx->strlen_fn), ctx->strlen_fn,
|
||||
(LLVMValueRef[]){r}, 1, "strlen_r");
|
||||
// total = len_l + len_r + 1
|
||||
LLVMValueRef total = LLVMBuildAdd(ctx->builder, len_l, len_r, "total");
|
||||
total = LLVMBuildAdd(ctx->builder, total,
|
||||
LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "total_1");
|
||||
// char* buf = malloc(total)
|
||||
LLVMValueRef buf = LLVMBuildCall2(ctx->builder,
|
||||
LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn,
|
||||
(LLVMValueRef[]){total}, 1, "str_buf");
|
||||
// memcpy(buf, left, len_l)
|
||||
LLVMBuildCall2(ctx->builder,
|
||||
LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn,
|
||||
(LLVMValueRef[]){buf, l, len_l}, 3, "");
|
||||
// memcpy(buf + len_l, right, len_r + 1) -- includes null terminator
|
||||
LLVMValueRef offset_ptr = LLVMBuildGEP2(ctx->builder,
|
||||
LLVMInt8TypeInContext(ctx->context), buf,
|
||||
(LLVMValueRef[]){len_l}, 1, "offset");
|
||||
LLVMValueRef len_r1 = LLVMBuildAdd(ctx->builder, len_r,
|
||||
LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 1, false), "len_r1");
|
||||
LLVMBuildCall2(ctx->builder,
|
||||
LLVMGlobalGetValueType(ctx->memcpy_fn), ctx->memcpy_fn,
|
||||
(LLVMValueRef[]){offset_ptr, r, len_r1}, 3, "");
|
||||
return buf;
|
||||
}
|
||||
|
||||
bool is_float = (node->type.kind == TYPE_F64);
|
||||
|
||||
switch (node->as.binary.op) {
|
||||
case OP_ADD:
|
||||
return is_float ? LLVMBuildFAdd(ctx->builder, l, r, "fadd")
|
||||
: LLVMBuildAdd(ctx->builder, l, r, "iadd");
|
||||
case OP_SUB:
|
||||
return is_float ? LLVMBuildFSub(ctx->builder, l, r, "fsub")
|
||||
: LLVMBuildSub(ctx->builder, l, r, "isub");
|
||||
case OP_MUL:
|
||||
return is_float ? LLVMBuildFMul(ctx->builder, l, r, "fmul")
|
||||
: LLVMBuildMul(ctx->builder, l, r, "imul");
|
||||
case OP_DIV:
|
||||
return is_float ? LLVMBuildFDiv(ctx->builder, l, r, "fdiv")
|
||||
: LLVMBuildSDiv(ctx->builder, l, r, "sdiv");
|
||||
case OP_MOD:
|
||||
return LLVMBuildSRem(ctx->builder, l, r, "srem");
|
||||
case OP_EQ:
|
||||
case OP_NE:
|
||||
case OP_LT:
|
||||
case OP_GT:
|
||||
case OP_LE:
|
||||
case OP_GE: {
|
||||
// 枚举比较: 提取 tag 字段再比较
|
||||
LLVMValueRef cl = l, cr = r;
|
||||
if (node->as.binary.left->type.kind == TYPE_ENUM) {
|
||||
cl = LLVMBuildExtractValue(ctx->builder, l, 0, "enum_tag_l");
|
||||
}
|
||||
if (node->as.binary.right->type.kind == TYPE_ENUM) {
|
||||
cr = LLVMBuildExtractValue(ctx->builder, r, 0, "enum_tag_r");
|
||||
}
|
||||
LLVMIntPredicate pred;
|
||||
switch (node->as.binary.op) {
|
||||
case OP_EQ: pred = LLVMIntEQ; break;
|
||||
case OP_NE: pred = LLVMIntNE; break;
|
||||
case OP_LT: pred = LLVMIntSLT; break;
|
||||
case OP_GT: pred = LLVMIntSGT; break;
|
||||
case OP_LE: pred = LLVMIntSLE; break;
|
||||
case OP_GE: pred = LLVMIntSGE; break;
|
||||
default: return NULL;
|
||||
}
|
||||
if (is_float)
|
||||
return LLVMBuildFCmp(ctx->builder, pred == LLVMIntEQ ? LLVMRealOEQ :
|
||||
pred == LLVMIntNE ? LLVMRealONE : pred == LLVMIntSLT ? LLVMRealOLT :
|
||||
pred == LLVMIntSGT ? LLVMRealOGT : pred == LLVMIntSLE ? LLVMRealOLE :
|
||||
LLVMRealOGE, cl, cr, "fcmp");
|
||||
return LLVMBuildICmp(ctx->builder, pred, cl, cr, "icmp");
|
||||
}
|
||||
case OP_AND:
|
||||
return LLVMBuildAnd(ctx->builder, l, r, "and");
|
||||
case OP_OR:
|
||||
return LLVMBuildOr(ctx->builder, l, r, "or");
|
||||
default:
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
case AST_CALL_EXPR: {
|
||||
// === 内置 print 函数:委托给 printf ===
|
||||
if (strcmp(node->as.call.name, "print_i64") == 0) {
|
||||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||||
if (!arg) return NULL;
|
||||
// 枚举类型: 提取 tag 字段
|
||||
if (node->as.call.args[0]->type.kind == TYPE_ENUM)
|
||||
arg = LLVMBuildExtractValue(ctx->builder, arg, 0, "tag");
|
||||
LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context);
|
||||
arg = coerce_int(ctx, arg, LLVMTypeOf(arg), i64_ty);
|
||||
LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%lld\n", "fmt_i64");
|
||||
LLVMValueRef printf_args[] = { fmt, arg };
|
||||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||||
printf_args, 2, "");
|
||||
}
|
||||
if (strcmp(node->as.call.name, "print_f64") == 0) {
|
||||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||||
if (!arg) return NULL;
|
||||
LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%f\n", "fmt_f64");
|
||||
LLVMValueRef printf_args[] = { fmt, arg };
|
||||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||||
printf_args, 2, "");
|
||||
}
|
||||
if (strcmp(node->as.call.name, "print_bool") == 0) {
|
||||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||||
if (!arg) return NULL;
|
||||
// 将 bool 转为字符串:通过 select 在 "true\n" 和 "false\n" 之间选择
|
||||
LLVMValueRef c = LLVMBuildICmp(ctx->builder, LLVMIntEQ, arg,
|
||||
LLVMConstInt(LLVMInt1TypeInContext(ctx->context), 1, false), "bool_cmp");
|
||||
LLVMValueRef true_str = LLVMBuildGlobalStringPtr(ctx->builder, "true\n", "true_str");
|
||||
LLVMValueRef false_str = LLVMBuildGlobalStringPtr(ctx->builder, "false\n", "false_str");
|
||||
LLVMValueRef selected = LLVMBuildSelect(ctx->builder, c, true_str, false_str, "bool_sel");
|
||||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||||
(LLVMValueRef[]){selected}, 1, "");
|
||||
}
|
||||
if (strcmp(node->as.call.name, "print_str") == 0) {
|
||||
LLVMValueRef arg = codegen_expr(ctx, node->as.call.args[0]);
|
||||
if (!arg) return NULL;
|
||||
LLVMValueRef fmt = LLVMBuildGlobalStringPtr(ctx->builder, "%s\n", "fmt_str");
|
||||
LLVMValueRef printf_args[] = { fmt, arg };
|
||||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||||
printf_args, 2, "");
|
||||
}
|
||||
|
||||
// === 常规函数调用 ===
|
||||
LLVMValueRef fn = find_fn(ctx, node->as.call.name);
|
||||
if (!fn) return NULL;
|
||||
LLVMValueRef args[16];
|
||||
if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; }
|
||||
for (size_t i = 0; i < node->as.call.arg_count; i++) {
|
||||
args[i] = codegen_expr(ctx, node->as.call.args[i]);
|
||||
if (!args[i]) return NULL;
|
||||
}
|
||||
LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
|
||||
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
|
||||
return LLVMBuildCall2(ctx->builder, fn_ty, fn,
|
||||
args, (unsigned)node->as.call.arg_count,
|
||||
ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call");
|
||||
}
|
||||
|
||||
// === 结构体字段访问: p.x ===
|
||||
case AST_FIELD_ACCESS: {
|
||||
// 对对象求值(返回的是 struct 值)
|
||||
LLVMValueRef struct_val = codegen_expr(ctx, node->as.field_access.object);
|
||||
if (!struct_val) return NULL;
|
||||
|
||||
int field_idx = node->as.field_access.field_index;
|
||||
if (field_idx < 0) return NULL; // sema 应当已经设置
|
||||
|
||||
// 用 extractvalue 从结构体值中提取字段
|
||||
return LLVMBuildExtractValue(ctx->builder, struct_val,
|
||||
(unsigned)field_idx, node->as.field_access.field);
|
||||
}
|
||||
|
||||
// === 结构体初始化: Point { x: 10, y: 20 } ===
|
||||
case AST_STRUCT_INIT: {
|
||||
const char* st_name = node->as.struct_init.type_name;
|
||||
LLVMTypeRef struct_ty = find_struct_type(ctx, st_name);
|
||||
if (!struct_ty) return NULL;
|
||||
|
||||
// alloca 分配结构体空间
|
||||
LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, struct_ty, "struct_init");
|
||||
|
||||
// 获取结构体字段名列表(从 struct_table 或从 AST 中)
|
||||
// 对每个 init 字段,找到它在结构体中的索引并 store
|
||||
for (size_t i = 0; i < node->as.struct_init.field_count; i++) {
|
||||
AstNode* fval = node->as.struct_init.field_values[i];
|
||||
LLVMValueRef val = codegen_expr(ctx, fval);
|
||||
if (!val) return NULL;
|
||||
|
||||
// 获取字段指针: GEP struct_ty, alloca, 0, i
|
||||
LLVMValueRef indices[] = {
|
||||
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
|
||||
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), (unsigned long long)i, false)
|
||||
};
|
||||
LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, struct_ty, alloca,
|
||||
indices, 2, "field_ptr");
|
||||
LLVMBuildStore(ctx->builder, val, field_ptr);
|
||||
}
|
||||
|
||||
// 加载整个结构体值
|
||||
return LLVMBuildLoad2(ctx->builder, struct_ty, alloca, "struct_val");
|
||||
}
|
||||
|
||||
case AST_ENUM_VARIANT: {
|
||||
// tagged union: { tag, payload }
|
||||
LLVMValueRef tag = LLVMConstInt(LLVMInt64TypeInContext(ctx->context),
|
||||
(unsigned long long)node->as.enum_variant.variant_index, true);
|
||||
LLVMValueRef payload = LLVMConstInt(LLVMInt64TypeInContext(ctx->context), 0, true);
|
||||
if (node->as.enum_variant.payload) {
|
||||
LLVMValueRef pv = codegen_expr(ctx, node->as.enum_variant.payload);
|
||||
if (pv) {
|
||||
// 将 payload 强制转换为 i64
|
||||
LLVMTypeRef pv_ty = LLVMTypeOf(pv);
|
||||
LLVMTypeRef i64_ty = LLVMInt64TypeInContext(ctx->context);
|
||||
if (pv_ty != i64_ty && LLVMGetTypeKind(pv_ty) == LLVMIntegerTypeKind)
|
||||
pv = coerce_int(ctx, pv, pv_ty, i64_ty);
|
||||
payload = pv;
|
||||
}
|
||||
}
|
||||
LLVMValueRef fields[] = { tag, payload };
|
||||
return LLVMConstStruct(fields, 2, false);
|
||||
}
|
||||
|
||||
case AST_METHOD_CALL: {
|
||||
const char* struct_name = node->as.method_call.receiver->type.struct_name;
|
||||
char mangled[256];
|
||||
// 若 method_name 已含 $(trait 方法,sema 已设置全限定名),直接用
|
||||
if (strchr(node->as.method_call.method_name, '$'))
|
||||
snprintf(mangled, sizeof(mangled), "%s", node->as.method_call.method_name);
|
||||
else
|
||||
snprintf(mangled, sizeof(mangled), "%s$%s", struct_name,
|
||||
node->as.method_call.method_name);
|
||||
LLVMValueRef fn = find_fn(ctx, mangled);
|
||||
if (!fn) return NULL;
|
||||
// 参数列表: [receiver, 用户参数...]
|
||||
if (node->as.method_call.arg_count + 1 > 16) { ctx->error = "方法参数过多(最多15)"; return NULL; }
|
||||
LLVMValueRef args[16];
|
||||
args[0] = codegen_expr(ctx, node->as.method_call.receiver);
|
||||
if (!args[0]) return NULL;
|
||||
for (size_t i = 0; i < node->as.method_call.arg_count; i++) {
|
||||
args[i + 1] = codegen_expr(ctx, node->as.method_call.args[i]);
|
||||
if (!args[i + 1]) return NULL;
|
||||
}
|
||||
LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
|
||||
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
|
||||
return LLVMBuildCall2(ctx->builder, fn_ty, fn, args,
|
||||
(unsigned)(node->as.method_call.arg_count + 1),
|
||||
ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "method_call");
|
||||
}
|
||||
|
||||
case AST_INDEX_EXPR: {
|
||||
// 获取数组变量的指针
|
||||
AstNode* arr_node = node->as.index_expr.array;
|
||||
LLVMValueRef arr_ptr = NULL;
|
||||
LLVMTypeRef arr_gp_type = NULL;
|
||||
|
||||
if (arr_node->kind == AST_IDENT_EXPR) {
|
||||
arr_ptr = find_var(ctx, arr_node->as.ident.name);
|
||||
// 从变量表获取数组类型用于 GEP
|
||||
for (VarEntry* e = ctx->var_table; e; e = e->next) {
|
||||
if (strcmp(e->name, arr_node->as.ident.name) == 0) {
|
||||
arr_gp_type = e->alloca_type; break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!arr_ptr || !arr_gp_type) return NULL;
|
||||
|
||||
// 生成索引值
|
||||
LLVMValueRef idx_val = codegen_expr(ctx, node->as.index_expr.index);
|
||||
if (!idx_val) return NULL;
|
||||
|
||||
// GEP 索引必须是 i32,但 L 使用 i64。截断。
|
||||
LLVMValueRef idx_i32 = LLVMBuildTrunc(ctx->builder, idx_val,
|
||||
LLVMInt32TypeInContext(ctx->context), "idx32");
|
||||
|
||||
LLVMValueRef indices[] = {
|
||||
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
|
||||
idx_i32
|
||||
};
|
||||
LLVMValueRef elem_ptr = LLVMBuildGEP2(ctx->builder, arr_gp_type, arr_ptr, indices, 2, "arr_elem");
|
||||
|
||||
LLVMTypeRef elem_load_ty;
|
||||
if (node->type.kind == TYPE_STRUCT && node->type.struct_name) {
|
||||
elem_load_ty = find_struct_type(ctx, node->type.struct_name);
|
||||
if (!elem_load_ty) elem_load_ty = to_llvm_type(ctx, node->type.kind);
|
||||
} else {
|
||||
elem_load_ty = type_info_to_llvm(ctx, &node->type);
|
||||
}
|
||||
return LLVMBuildLoad2(ctx->builder, elem_load_ty, elem_ptr, "arr_load");
|
||||
}
|
||||
|
||||
// 块表达式: { stmt*; expr } → 最后表达式的值
|
||||
case AST_BLOCK: {
|
||||
LLVMValueRef result = NULL;
|
||||
for (size_t i = 0; i < node->as.block.stmt_count; i++) {
|
||||
AstNode* stmt = node->as.block.stmts[i];
|
||||
bool is_last = (i == node->as.block.stmt_count - 1);
|
||||
if (is_last && stmt->kind == AST_EXPR_STMT && node->type.kind != TYPE_VOID) {
|
||||
result = codegen_expr(ctx, stmt->as.expr_stmt.expr);
|
||||
} else {
|
||||
codegen_stmt(ctx, stmt);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// if 表达式: if cond { a } else { b }
|
||||
case AST_IF_STMT: {
|
||||
if (node->type.kind == TYPE_VOID) { codegen_stmt(ctx, node); return NULL; }
|
||||
LLVMValueRef cond_val = codegen_expr(ctx, node->as.if_stmt.cond);
|
||||
if (!cond_val) return NULL;
|
||||
LLVMTypeRef res_ty = type_info_to_llvm(ctx, &node->type);
|
||||
LLVMValueRef alloca = LLVMBuildAlloca(ctx->builder, res_ty, "if_res");
|
||||
LLVMValueRef func = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx->builder));
|
||||
LLVMBasicBlockRef then_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "then");
|
||||
LLVMBasicBlockRef else_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "else");
|
||||
LLVMBasicBlockRef merge_bb = LLVMAppendBasicBlockInContext(ctx->context, func, "if_merge");
|
||||
LLVMBuildCondBr(ctx->builder, cond_val, then_bb, else_bb);
|
||||
|
||||
LLVMPositionBuilderAtEnd(ctx->builder, then_bb);
|
||||
LLVMValueRef then_val = codegen_expr(ctx, node->as.if_stmt.then_block);
|
||||
if (then_val) LLVMBuildStore(ctx->builder, then_val, alloca);
|
||||
LLVMBuildBr(ctx->builder, merge_bb);
|
||||
|
||||
LLVMPositionBuilderAtEnd(ctx->builder, else_bb);
|
||||
LLVMValueRef else_val = codegen_expr(ctx, node->as.if_stmt.else_block);
|
||||
if (else_val) LLVMBuildStore(ctx->builder, else_val, alloca);
|
||||
LLVMBuildBr(ctx->builder, merge_bb);
|
||||
|
||||
LLVMPositionBuilderAtEnd(ctx->builder, merge_bb);
|
||||
return LLVMBuildLoad2(ctx->builder, res_ty, alloca, "if_val");
|
||||
}
|
||||
|
||||
default:
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
// === 自动内存管理: 作用域退出时释放 str 堆分配 ===
|
||||
static void cleanup_add(CgCtx* ctx, LLVMValueRef alloca) {
|
||||
void cleanup_add(CgCtx* ctx, LLVMValueRef alloca) {
|
||||
if (ctx->cleanup_count >= ctx->cleanup_cap) {
|
||||
size_t new_cap = ctx->cleanup_cap ? ctx->cleanup_cap * 2 : 16;
|
||||
LLVMValueRef* new_list = arena_alloc(ctx->arena, new_cap * sizeof(LLVMValueRef));
|
||||
@@ -559,7 +65,7 @@ static void cleanup_add(CgCtx* ctx, LLVMValueRef alloca) {
|
||||
}
|
||||
|
||||
// 释放从 mark 位置开始的所有 str 变量
|
||||
static void cleanup_emit(CgCtx* ctx, size_t from_mark) {
|
||||
void cleanup_emit(CgCtx* ctx, size_t from_mark) {
|
||||
for (size_t j = from_mark; j < ctx->cleanup_count; j++) {
|
||||
LLVMValueRef ptr = ctx->cleanup_list[j];
|
||||
LLVMValueRef val = LLVMBuildLoad2(ctx->builder,
|
||||
@@ -572,7 +78,7 @@ static void cleanup_emit(CgCtx* ctx, size_t from_mark) {
|
||||
}
|
||||
|
||||
// === 语句代码生成 ===
|
||||
static void codegen_stmt(CgCtx* ctx, AstNode* node) {
|
||||
void codegen_stmt(CgCtx* ctx, AstNode* node) {
|
||||
if (!node) return;
|
||||
|
||||
switch (node->kind) {
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
#ifndef CODEGEN_INTERNAL_H
|
||||
#define CODEGEN_INTERNAL_H
|
||||
|
||||
#include "codegen.h"
|
||||
#include "ast.h"
|
||||
#include "arena.h"
|
||||
#include "l_lang.h"
|
||||
#include <llvm-c/Analysis.h>
|
||||
#include <llvm-c/Types.h>
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
|
||||
// 递归深度限制
|
||||
extern int codegen_depth;
|
||||
#define MAX_CODEGEN_DEPTH 1000
|
||||
|
||||
// === 内部状态 ===
|
||||
typedef struct VarEntry {
|
||||
const char* name;
|
||||
LLVMValueRef alloca;
|
||||
LLVMTypeRef alloca_type;
|
||||
struct VarEntry* next;
|
||||
} VarEntry;
|
||||
|
||||
typedef struct FnEntry {
|
||||
const char* name;
|
||||
LLVMValueRef fn;
|
||||
TypeKind ret;
|
||||
TypeKind* params;
|
||||
size_t pc;
|
||||
struct FnEntry* next;
|
||||
} FnEntry;
|
||||
|
||||
typedef struct StructTypeEntry {
|
||||
const char* name;
|
||||
LLVMTypeRef llvm_type;
|
||||
size_t field_count;
|
||||
struct StructTypeEntry* next;
|
||||
} StructTypeEntry;
|
||||
|
||||
typedef struct {
|
||||
Arena* arena;
|
||||
LLVMContextRef context;
|
||||
LLVMModuleRef module;
|
||||
LLVMBuilderRef builder;
|
||||
VarEntry* var_table;
|
||||
const char* error;
|
||||
FnEntry* fn_table;
|
||||
StructTypeEntry* struct_table;
|
||||
LLVMValueRef printf_fn;
|
||||
LLVMTypeRef printf_ty;
|
||||
LLVMValueRef malloc_fn;
|
||||
LLVMValueRef free_fn;
|
||||
LLVMValueRef strlen_fn;
|
||||
LLVMValueRef memcpy_fn;
|
||||
LLVMValueRef* cleanup_list;
|
||||
size_t cleanup_count;
|
||||
size_t cleanup_cap;
|
||||
} CgCtx;
|
||||
|
||||
// === 类型映射 ===
|
||||
LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind);
|
||||
LLVMValueRef to_llvm_const(LLVMTypeRef ty, AstNode* lit);
|
||||
LLVMTypeRef type_info_to_llvm(CgCtx* ctx, const TypeInfo* ti);
|
||||
LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val, LLVMTypeRef from_ty, LLVMTypeRef to_ty);
|
||||
|
||||
// === 表操作 ===
|
||||
LLVMValueRef find_var(CgCtx* ctx, const char* name);
|
||||
void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type);
|
||||
LLVMValueRef find_fn(CgCtx* ctx, const char* name);
|
||||
void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn);
|
||||
void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc);
|
||||
LLVMTypeRef find_struct_type(CgCtx* ctx, const char* name);
|
||||
|
||||
// === 内存清理 ===
|
||||
void cleanup_add(CgCtx* ctx, LLVMValueRef alloca);
|
||||
void cleanup_emit(CgCtx* ctx, size_t from_mark);
|
||||
|
||||
// === 代码生成函数 ===
|
||||
LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node);
|
||||
void codegen_stmt(CgCtx* ctx, AstNode* node);
|
||||
|
||||
#endif
|
||||
+2
-632
@@ -1,636 +1,6 @@
|
||||
#include "sema.h"
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include "sema_internal.h"
|
||||
|
||||
// === 泛型单态化: AST 类型替换 ===
|
||||
// 将 AST 中所有匹配 type_param_name 的类型引用替换为 concrete_type
|
||||
static void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname);
|
||||
|
||||
// Rewrite one TypeInfo in place if it refers to the generic parameter `tparam`.
// A type-parameter reference is recognised as TYPE_STRUCT whose struct_name
// equals the parameter name; it becomes `concrete` / `concrete_sname`.
static void subst_type_info(TypeInfo* ti, const char* tparam, TypeKind concrete, const char* concrete_sname) {
    if (ti->kind != TYPE_STRUCT) return;
    if (!ti->struct_name) return;
    if (strcmp(ti->struct_name, tparam) != 0) return;

    ti->kind = concrete;
    ti->struct_name = concrete_sname;
}
|
||||
|
||||
// Recursively replace every reference to the generic parameter `tparam` in
// the subtree rooted at `node` with the concrete type `concrete` (and struct
// name `concrete_sname`, which may be NULL for non-struct types).
//
// The substitution is done in place: the node's own TypeInfo, the explicit
// type annotations on functions, parameters and let statements, and then all
// child nodes. A NULL node is ignored, so optional children can be passed
// without checking.
static void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname) {
    if (!node) return;

    // The node's own inferred/declared type.
    subst_type_info(&node->type, tparam, concrete, concrete_sname);

    switch (node->kind) {
        case AST_PROGRAM:
            for (size_t i = 0; i < node->as.program.fn_count; i++)
                subst_ast_types(node->as.program.functions[i], tparam, concrete, concrete_sname);
            break;

        case AST_FUNCTION:
            // Declared return type.
            if (node->as.function.return_type == TYPE_STRUCT
                && node->as.function.return_struct_type_name
                && strcmp(node->as.function.return_struct_type_name, tparam) == 0) {
                node->as.function.return_type = concrete;
                node->as.function.return_struct_type_name = concrete_sname;
            }
            for (size_t i = 0; i < node->as.function.param_count; i++)
                subst_ast_types(node->as.function.params[i], tparam, concrete, concrete_sname);
            subst_ast_types(node->as.function.body, tparam, concrete, concrete_sname);
            break;

        case AST_PARAMETER:
            if (node->as.parameter.type == TYPE_STRUCT
                && node->as.parameter.struct_type_name
                && strcmp(node->as.parameter.struct_type_name, tparam) == 0) {
                node->as.parameter.type = concrete;
                node->as.parameter.struct_type_name = concrete_sname;
            }
            break;

        case AST_BLOCK:
            for (size_t i = 0; i < node->as.block.stmt_count; i++)
                subst_ast_types(node->as.block.stmts[i], tparam, concrete, concrete_sname);
            break;

        case AST_LET_STMT:
            // Explicit type annotation on the binding.
            if (node->as.let_stmt.annot_type == TYPE_STRUCT
                && node->as.let_stmt.struct_type_name
                && strcmp(node->as.let_stmt.struct_type_name, tparam) == 0) {
                node->as.let_stmt.annot_type = concrete;
                node->as.let_stmt.struct_type_name = concrete_sname;
            }
            subst_ast_types(node->as.let_stmt.init, tparam, concrete, concrete_sname);
            break;

        case AST_IF_STMT:
            subst_ast_types(node->as.if_stmt.cond, tparam, concrete, concrete_sname);
            subst_ast_types(node->as.if_stmt.then_block, tparam, concrete, concrete_sname);
            subst_ast_types(node->as.if_stmt.else_block, tparam, concrete, concrete_sname);
            break;

        case AST_WHILE_STMT:
            subst_ast_types(node->as.while_stmt.cond, tparam, concrete, concrete_sname);
            subst_ast_types(node->as.while_stmt.body, tparam, concrete, concrete_sname);
            break;

        case AST_RETURN_STMT:
            subst_ast_types(node->as.return_stmt.expr, tparam, concrete, concrete_sname);
            break;

        case AST_EXPR_STMT:
            subst_ast_types(node->as.expr_stmt.expr, tparam, concrete, concrete_sname);
            break;

        case AST_BINARY_EXPR:
            subst_ast_types(node->as.binary.left, tparam, concrete, concrete_sname);
            subst_ast_types(node->as.binary.right, tparam, concrete, concrete_sname);
            break;

        case AST_UNARY_EXPR:
            subst_ast_types(node->as.unary.operand, tparam, concrete, concrete_sname);
            break;

        case AST_CALL_EXPR:
            for (size_t i = 0; i < node->as.call.arg_count; i++)
                subst_ast_types(node->as.call.args[i], tparam, concrete, concrete_sname);
            break;

        case AST_ASSIGN_STMT:
            subst_ast_types(node->as.assign_stmt.value, tparam, concrete, concrete_sname);
            break;

        case AST_FIELD_ACCESS:
            subst_ast_types(node->as.field_access.object, tparam, concrete, concrete_sname);
            break;

        case AST_STRUCT_INIT:
            for (size_t i = 0; i < node->as.struct_init.field_count; i++)
                subst_ast_types(node->as.struct_init.field_values[i], tparam, concrete, concrete_sname);
            break;

        // The three cases below were previously not traversed, so any typed
        // node nested under a method call, an index expression or an enum
        // payload (e.g. a block expression containing an annotated `let`)
        // kept referring to the unsubstituted type parameter.
        case AST_METHOD_CALL:
            subst_ast_types(node->as.method_call.receiver, tparam, concrete, concrete_sname);
            for (size_t i = 0; i < node->as.method_call.arg_count; i++)
                subst_ast_types(node->as.method_call.args[i], tparam, concrete, concrete_sname);
            break;

        case AST_INDEX_EXPR:
            subst_ast_types(node->as.index_expr.array, tparam, concrete, concrete_sname);
            subst_ast_types(node->as.index_expr.index, tparam, concrete, concrete_sname);
            break;

        case AST_ENUM_VARIANT:
            // payload is optional; a NULL child is ignored by the callee.
            subst_ast_types(node->as.enum_variant.payload, tparam, concrete, concrete_sname);
            break;

        default:
            break;
    }
}
|
||||
|
||||
// === 类型关系 ===
|
||||
static TypeKind current_return_type = TYPE_VOID;
|
||||
static const char* current_return_struct_name = NULL;
|
||||
|
||||
// 单态化队列: 泛型函数调用时生成的具象化函数
|
||||
static AstNode* mono_queue[256];
|
||||
static size_t mono_count = 0;
|
||||
static Arena* mono_arena = NULL;
|
||||
static AstNode* g_program = NULL; // 当前 AST_PROGRAM(用于查找泛型函数模板)
|
||||
|
||||
// Result type of a binary arithmetic operation on operands of kinds `a` and `b`.
// Enums count as i64 and chars as i32; the result is the first kind in the
// ranking below that either operand has, or TYPE_ERROR when neither operand
// has any of them.
static TypeKind promote(TypeKind a, TypeKind b) {
    static const TypeKind ranking[] = {
        TYPE_F64, TYPE_I64, TYPE_U64, TYPE_I32, TYPE_BOOL
    };

    if (a == TYPE_ENUM) a = TYPE_I64;
    else if (a == TYPE_CHAR) a = TYPE_I32;

    if (b == TYPE_ENUM) b = TYPE_I64;
    else if (b == TYPE_CHAR) b = TYPE_I32;

    for (size_t i = 0; i < sizeof ranking / sizeof ranking[0]; i++) {
        if (a == ranking[i] || b == ranking[i])
            return ranking[i];
    }
    return TYPE_ERROR;
}
|
||||
|
||||
// True for the kinds accepted as arithmetic operands: the integer and float
// kinds plus char and enum.
static bool is_numeric(TypeKind t) {
    switch (t) {
        case TYPE_I32:
        case TYPE_I64:
        case TYPE_U64:
        case TYPE_F64:
        case TYPE_CHAR:
        case TYPE_ENUM:
            return true;
        default:
            return false;
    }
}
|
||||
// 隐式类型转换规则: 无损加宽允许,有符号→无符号不允许
|
||||
static bool can_implicit_convert(TypeKind from, TypeKind to) {
|
||||
if (from == to) return true;
|
||||
// 枚举视为 i64
|
||||
if (from == TYPE_ENUM) from = TYPE_I64;
|
||||
if (to == TYPE_ENUM) to = TYPE_I64;
|
||||
// char 可转为任意整数
|
||||
if (from == TYPE_CHAR) return to == TYPE_I32 || to == TYPE_I64 || to == TYPE_U64 || to == TYPE_F64;
|
||||
// i32 可加宽
|
||||
if (from == TYPE_I32) return to == TYPE_I64 || to == TYPE_F64;
|
||||
// i64 可转 f64
|
||||
if (from == TYPE_I64) return to == TYPE_F64;
|
||||
// u64 ↔ i64 双向允许(同一位宽,LLVM 同类型)
|
||||
if (from == TYPE_U64) return to == TYPE_F64 || to == TYPE_I64;
|
||||
if (from == TYPE_I64) return to == TYPE_F64 || to == TYPE_U64;
|
||||
return false;
|
||||
}
|
||||
// True when values of kinds `a` and `b` may be compared with a relational or
// equality operator: identical kinds, or an enum against an i64.
static bool is_comparable(TypeKind a, TypeKind b) {
    const bool same_kind = (a == b);
    const bool enum_vs_i64 = (a == TYPE_ENUM && b == TYPE_I64)
                          || (a == TYPE_I64 && b == TYPE_ENUM);
    return same_kind || enum_vs_i64;
}
|
||||
|
||||
// === 向前声明 ===
|
||||
static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
|
||||
|
||||
// === 表达式类型检查辅助函数 ===
|
||||
static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
|
||||
|
||||
// Type-check an identifier expression: resolve the name in `scope` and copy
// the symbol's type information onto the node. Unknown names, type aliases
// and function names are reported and the node becomes TYPE_ERROR.
static void analyze_ident_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    (void)a;

    Symbol* sym = scope_lookup(scope, node->as.ident.name);

    if (!sym) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的变量 '%s'", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sym->is_type_alias) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'%s' 是类型别名,不能作为表达式使用", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sym->kind == SYM_FUNCTION) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'%s' 是函数,不能作为表达式使用", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    // A plain variable: propagate its kind plus the extra details that
    // struct and array kinds carry.
    node->type.kind = sym->type;
    if (sym->type == TYPE_STRUCT) {
        if (sym->struct_type_name)
            node->type.struct_name = sym->struct_type_name;
    } else if (sym->type == TYPE_ARRAY) {
        node->type.element_type = sym->array_element_type;
        node->type.element_struct_name = sym->array_element_struct_name;
        node->type.array_size = sym->array_size;
    }
}
|
||||
|
||||
// Type-check a unary expression. Negation requires a numeric operand and
// keeps its kind; every other operator is treated as logical not, which
// requires and yields bool.
static void analyze_unary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* operand = node->as.unary.operand;
    analyze_expr(operand, scope, errors, a);
    const TypeKind inner = operand->type.kind;

    if (node->as.unary.op != OP_NEG) {
        // Logical not.
        if (inner == TYPE_BOOL) {
            node->type.kind = TYPE_BOOL;
            return;
        }
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'!' 只能用于布尔类型,得到 '%s'", type_name(inner));
        node->type.kind = TYPE_ERROR;
        return;
    }

    // Arithmetic negation.
    if (is_numeric(inner)) {
        node->type.kind = inner;
        return;
    }
    error_add(errors, "<sema>", node->loc.line, node->loc.col,
              "一元 '-' 只能用于数值类型");
    node->type.kind = TYPE_ERROR;
}
|
||||
|
||||
// Type-check a binary expression. Both operands are analysed first (left,
// then right); if either is already TYPE_ERROR the error is propagated
// without a new diagnostic. Otherwise:
//   +            str + str yields str; mixing str with anything else is an
//                error; non-string operands follow the arithmetic rule
//   - * / %      both operands numeric, result is promote(l, r)
//   comparisons  operands must be comparable, result is bool
//   && ||        both operands bool, result is bool
// Any other operator leaves the node's type untouched.
static void analyze_binary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* lhs = node->as.binary.left;
    AstNode* rhs = node->as.binary.right;
    analyze_expr(lhs, scope, errors, a);
    analyze_expr(rhs, scope, errors, a);

    const TypeKind l = lhs->type.kind;
    const TypeKind r = rhs->type.kind;
    if (l == TYPE_ERROR || r == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    switch (node->as.binary.op) {
        case OP_ADD:
            if (l == TYPE_STR || r == TYPE_STR) {
                // String concatenation needs str on both sides.
                if (l == TYPE_STR && r == TYPE_STR) {
                    node->type.kind = TYPE_STR;
                } else {
                    error_add(errors, "<sema>", node->loc.line, node->loc.col,
                              "字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'",
                              type_name(l), type_name(r));
                    node->type.kind = TYPE_ERROR;
                }
                return;
            }
            /* fallthrough */
        case OP_SUB:
        case OP_MUL:
        case OP_DIV:
        case OP_MOD:
            if (is_numeric(l) && is_numeric(r)) {
                node->type.kind = promote(l, r);
            } else {
                error_add(errors, "<sema>", node->loc.line, node->loc.col, "算术运算需要数值类型");
                node->type.kind = TYPE_ERROR;
            }
            return;

        case OP_EQ:
        case OP_NE:
        case OP_LT:
        case OP_GT:
        case OP_LE:
        case OP_GE:
            if (is_comparable(l, r)) {
                node->type.kind = TYPE_BOOL;
            } else {
                error_add(errors, "<sema>", node->loc.line, node->loc.col,
                          "类型 '%s' 和 '%s' 无法比较", type_name(l), type_name(r));
                node->type.kind = TYPE_ERROR;
            }
            return;

        case OP_AND:
        case OP_OR:
            if (l == TYPE_BOOL && r == TYPE_BOOL) {
                node->type.kind = TYPE_BOOL;
            } else {
                error_add(errors, "<sema>", node->loc.line, node->loc.col, "逻辑运算需要布尔类型");
                node->type.kind = TYPE_ERROR;
            }
            return;

        default:
            return;
    }
}
|
||||
|
||||
// 参数类型匹配检查(CALL_EXPR 和 METHOD_CALL 共用)
|
||||
static bool check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname,
|
||||
size_t idx, AstNode* call_node, Symbol* fn_sym,
|
||||
ErrorList* errors, Arena* a) {
|
||||
(void)a;
|
||||
TypeKind actual = arg->type.kind;
|
||||
if (actual == TYPE_ERROR) return false;
|
||||
if (expected == TYPE_STRUCT && expected_sname) {
|
||||
// 检查是否是泛型类型参数(匹配则接受任意类型)
|
||||
if (fn_sym && fn_sym->type_params) {
|
||||
for (size_t t = 0; t < fn_sym->type_param_count; t++) {
|
||||
if (strcmp(expected_sname, fn_sym->type_params[t]) == 0)
|
||||
return true; // 泛型参数,接受任意类型
|
||||
}
|
||||
}
|
||||
const char* actual_name = arg->type.struct_name;
|
||||
if (actual != TYPE_STRUCT || !actual_name ||
|
||||
strcmp(actual_name, expected_sname) != 0) {
|
||||
error_add(errors, "<sema>", call_node->loc.line, call_node->loc.col,
|
||||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||||
idx + 1, expected_sname ? expected_sname : "struct",
|
||||
actual_name ? actual_name : type_name(actual));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if (actual == expected) return false;
|
||||
if (expected == TYPE_I64 && actual == TYPE_ENUM) return false;
|
||||
if (can_implicit_convert(actual, expected)) return false;
|
||||
if (actual == TYPE_I64 && arg->kind == AST_LITERAL_EXPR
|
||||
&& (expected == TYPE_I32 || expected == TYPE_U64 || expected == TYPE_CHAR)) return false;
|
||||
error_add(errors, "<sema>", call_node->loc.line, call_node->loc.col,
|
||||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||||
idx + 1, type_name(expected), type_name(actual));
|
||||
return false;
|
||||
}
|
||||
|
||||
// 命名参数重排序(CALL_EXPR 和 METHOD_CALL 共用)
|
||||
static bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset,
|
||||
ErrorList* errors, const char* call_name) {
|
||||
AstNode** args = node->as.call.args;
|
||||
const char** arg_names = node->as.call.arg_names;
|
||||
size_t arg_count = node->as.call.arg_count;
|
||||
if (!arg_names) return true;
|
||||
AstNode* reordered[16] = {0};
|
||||
for (size_t i = 0; i < arg_count; i++) {
|
||||
if (arg_names[i]) {
|
||||
bool found = false;
|
||||
for (size_t j = param_offset; j < sym->param_count; j++) {
|
||||
if (sym->param_names && sym->param_names[j] &&
|
||||
strcmp(arg_names[i], sym->param_names[j]) == 0) {
|
||||
reordered[j - param_offset] = args[i];
|
||||
found = true; break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||||
"'%s' 没有名为 '%s' 的参数", call_name, arg_names[i]);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
reordered[i] = args[i];
|
||||
}
|
||||
}
|
||||
memcpy(args, reordered, arg_count * sizeof(AstNode*));
|
||||
return true;
|
||||
}
|
||||
|
||||
static void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||||
Symbol* sym = scope_lookup(scope, node->as.call.name);
|
||||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||||
"未定义的函数 '%s'", node->as.call.name);
|
||||
node->type.kind = TYPE_ERROR;
|
||||
for (size_t i = 0; i < node->as.call.arg_count; i++)
|
||||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||||
return;
|
||||
}
|
||||
if (node->as.call.arg_count != sym->param_count) {
|
||||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||||
"函数 '%s' 需要 %zu 个参数,但提供了 %zu 个",
|
||||
node->as.call.name, sym->param_count, node->as.call.arg_count);
|
||||
node->type.kind = TYPE_ERROR;
|
||||
for (size_t i = 0; i < node->as.call.arg_count; i++)
|
||||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||||
return;
|
||||
}
|
||||
if (!reorder_named_args(node, sym, 0, errors, node->as.call.name)) {
|
||||
node->type.kind = TYPE_ERROR; return;
|
||||
}
|
||||
for (size_t i = 0; i < node->as.call.arg_count; i++) {
|
||||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||||
bool is_generic_param = check_arg_type(node->as.call.args[i], sym->param_types[i],
|
||||
sym->param_struct_names ? sym->param_struct_names[i] : NULL,
|
||||
i, node, sym, errors, a);
|
||||
// 泛型单态化: 创建具象化函数副本并注册
|
||||
if (is_generic_param && sym->type_params && sym->type_param_count > 0) {
|
||||
TypeKind concrete = node->as.call.args[i]->type.kind;
|
||||
const char* concrete_sn = node->as.call.args[i]->type.struct_name;
|
||||
// 构造 mangled 名: fn$concrete_type
|
||||
const char* ct_name = concrete_sn ? concrete_sn : type_name(concrete);
|
||||
int mname_len = snprintf(NULL, 0, "%s$%s", node->as.call.name, ct_name) + 1;
|
||||
char* mname = arena_alloc_impl(a, mname_len);
|
||||
snprintf(mname, mname_len, "%s$%s", node->as.call.name, ct_name);
|
||||
// 检查是否已存在
|
||||
Symbol* existing = scope_lookup(scope, mname);
|
||||
if (!existing && g_program) {
|
||||
// 查找原始泛型函数 AST 节点
|
||||
AstNode* generic_fn = NULL;
|
||||
for (size_t fn_i = 0; fn_i < g_program->as.program.fn_count; fn_i++) {
|
||||
if (strcmp(g_program->as.program.functions[fn_i]->as.function.name,
|
||||
node->as.call.name) == 0) {
|
||||
generic_fn = g_program->as.program.functions[fn_i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (generic_fn && mono_count < 256) {
|
||||
// 创建浅拷贝(共享 body,subst_ast_types 修改类型标注)
|
||||
AstNode* mono_fn = ast_make_function(a, mname,
|
||||
generic_fn->as.function.params,
|
||||
generic_fn->as.function.param_count,
|
||||
generic_fn->as.function.return_type,
|
||||
generic_fn->as.function.return_struct_type_name,
|
||||
generic_fn->as.function.body,
|
||||
false, NULL, 0,
|
||||
generic_fn->loc);
|
||||
// 类型替换: T → concrete
|
||||
subst_ast_types(mono_fn, sym->type_params[0], concrete, concrete_sn);
|
||||
// 注册到队列
|
||||
mono_queue[mono_count++] = mono_fn;
|
||||
// 注册符号(后续分析会处理函数体)
|
||||
TypeKind* mpts = mono_fn->as.function.param_count > 0
|
||||
? arena_alloc_impl(a, mono_fn->as.function.param_count * sizeof(TypeKind)) : NULL;
|
||||
for (size_t pj = 0; pj < mono_fn->as.function.param_count; pj++) {
|
||||
mpts[pj] = mono_fn->as.function.params[pj]->as.parameter.type;
|
||||
}
|
||||
scope_insert_function(scope, a, mname,
|
||||
mono_fn->as.function.return_type,
|
||||
mono_fn->as.function.return_struct_type_name,
|
||||
mpts, NULL, NULL,
|
||||
mono_fn->as.function.param_count, NULL, 0);
|
||||
}
|
||||
}
|
||||
// 重定向调用到单态化函数
|
||||
node->as.call.name = mname;
|
||||
sym = scope_lookup(scope, mname);
|
||||
if (!sym) { node->type.kind = TYPE_ERROR; return; }
|
||||
}
|
||||
}
|
||||
node->type.kind = sym->return_type;
|
||||
if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name)
|
||||
node->type.struct_name = sym->return_struct_type_name;
|
||||
}
|
||||
|
||||
// Type-check a field access `object.field`.
// The object is analysed first; it must be a struct whose type name resolves
// to a known struct containing the field. On success the node receives the
// field's type (and struct name, for struct-typed fields) and its index is
// stored in field_index. Every failure marks the node TYPE_ERROR; all but a
// propagated operand error also report a diagnostic.
static void analyze_field_access(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* obj = node->as.field_access.object;
    analyze_expr(obj, scope, errors, a);

    // An operand that already failed propagates silently.
    if (obj->type.kind == TYPE_ERROR) goto fail;

    if (obj->type.kind != TYPE_STRUCT) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "类型 '%s' 不是结构体,不能访问字段 '%s'",
                  type_name(obj->type.kind), node->as.field_access.field);
        goto fail;
    }

    const char* struct_name = obj->type.struct_name;
    if (!struct_name) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col, "无法确定结构体类型");
        goto fail;
    }

    Symbol* struct_sym = scope_lookup_struct(scope, struct_name);
    if (!struct_sym) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的结构体 '%s'", struct_name);
        goto fail;
    }

    int fi = scope_struct_field_index(struct_sym, node->as.field_access.field);
    if (fi < 0) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 没有字段 '%s'", struct_name, node->as.field_access.field);
        goto fail;
    }

    node->as.field_access.field_index = fi;
    node->type.kind = struct_sym->struct_field_types[fi];
    if (node->type.kind == TYPE_STRUCT
        && struct_sym->struct_field_struct_names
        && struct_sym->struct_field_struct_names[fi]) {
        node->type.struct_name = struct_sym->struct_field_struct_names[fi];
    }
    return;

fail:
    node->type.kind = TYPE_ERROR;
}
|
||||
|
||||
/*
 * Type-check a struct literal (`Name { field: value, ... }`).
 *
 * The type name is resolved either directly as a struct or through a type
 * alias that names one; in the alias case the node's type_name is rewritten
 * to the resolved struct name. The literal must supply exactly as many
 * fields as the struct declares, each field must exist, must be initialized
 * only once, and its value must have the declared type (for struct-typed
 * fields the struct names are compared as well when both are known).
 *
 * On success the node becomes TYPE_STRUCT with struct_name set to the
 * resolved name. Unknown types, count mismatches, unknown fields and
 * duplicated fields mark the node TYPE_ERROR. A plain value/type mismatch is
 * reported but, as before, does not by itself change the node's type.
 */
static void analyze_struct_init(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    const char* resolved = node->as.struct_init.type_name;
    Symbol* struct_sym = scope_lookup_struct(scope, resolved);
    if (!struct_sym) {
        /* Not a struct name: it may be a type alias that points at one. */
        Symbol* alias_sym = scope_lookup(scope, resolved);
        if (alias_sym && alias_sym->is_type_alias && alias_sym->struct_type_name) {
            resolved = alias_sym->struct_type_name;
            struct_sym = scope_lookup_struct(scope, resolved);
            node->as.struct_init.type_name = resolved;
        }
    }
    if (!struct_sym) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的结构体类型 '%s'", node->as.struct_init.type_name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (node->as.struct_init.field_count != struct_sym->struct_field_count) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 有 %zu 个字段,但提供了 %zu 个",
                  node->as.struct_init.type_name,
                  struct_sym->struct_field_count, node->as.struct_init.field_count);
        node->type.kind = TYPE_ERROR;
        return;
    }

    for (size_t i = 0; i < node->as.struct_init.field_count; i++) {
        const char* fname = node->as.struct_init.field_names[i];
        AstNode* fval = node->as.struct_init.field_values[i];
        analyze_expr(fval, scope, errors, a);

        int fi = scope_struct_field_index(struct_sym, fname);
        if (fi < 0) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "结构体 '%s' 没有字段 '%s'", node->as.struct_init.type_name, fname);
            node->type.kind = TYPE_ERROR;
            continue;
        }

        /* The count check alone cannot catch `S { x: 1, x: 2 }`: the counts
         * match while some other field is left uninitialized. */
        bool duplicate = false;
        for (size_t k = 0; k < i && !duplicate; k++) {
            duplicate = strcmp(node->as.struct_init.field_names[k], fname) == 0;
        }
        if (duplicate) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "字段 '%s' 被重复初始化", fname);
            node->type.kind = TYPE_ERROR;
            continue;
        }

        TypeKind expected = struct_sym->struct_field_types[fi];
        TypeKind actual = fval->type.kind;
        if (actual == TYPE_ERROR) {
            continue; /* the value already produced a diagnostic */
        }
        if (actual != expected) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'",
                      fname, type_name(expected), type_name(actual));
            continue;
        }
        /* Both sides are structs: the kinds agree, so compare the struct
         * names too. Only done when both names are known, to avoid reporting
         * on incomplete type information. */
        if (expected == TYPE_STRUCT && struct_sym->struct_field_struct_names) {
            const char* want = struct_sym->struct_field_struct_names[fi];
            const char* got = fval->type.struct_name;
            if (want && got && strcmp(want, got) != 0) {
                error_add(errors, "<sema>", node->loc.line, node->loc.col,
                          "字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'",
                          fname, want, got);
            }
        }
    }

    if (node->type.kind != TYPE_ERROR) {
        node->type.kind = TYPE_STRUCT;
        node->type.struct_name = resolved;
    }
}
|
||||
|
||||
/*
 * Type-check an enum variant expression (`Enum::Variant` with an optional
 * payload).
 *
 * Resolves the enum symbol and the variant's index (stored on the node), then
 * validates the payload, if one was written, against the payload type the
 * enum records for that variant. Unknown enums/variants and a payload on a
 * variant that takes none mark the node TYPE_ERROR; a payload of the wrong
 * type is reported but the node still becomes TYPE_ENUM.
 */
static void analyze_enum_variant(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    const char* enum_name = node->as.enum_variant.enum_name;
    const char* variant = node->as.enum_variant.variant_name;

    Symbol* enum_sym = scope_lookup_struct(scope, enum_name);
    if (enum_sym == NULL || enum_sym->kind != SYM_ENUM) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的枚举 '%s'", enum_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    const int vi = scope_enum_variant_index(enum_sym, variant);
    if (vi < 0) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "枚举 '%s' 没有变体 '%s'",
                  enum_name, variant);
        node->type.kind = TYPE_ERROR;
        return;
    }
    node->as.enum_variant.variant_index = vi;

    /* ADT payload check. Without a payload table every variant is treated as
     * expecting TYPE_VOID. */
    const TypeKind expected_pt = enum_sym->variant_payload_types
                                     ? enum_sym->variant_payload_types[vi]
                                     : TYPE_VOID;

    AstNode* payload = node->as.enum_variant.payload;
    if (payload == NULL) {
        node->type.kind = TYPE_ENUM;
        return;
    }

    if (enum_sym->variant_payload_types && expected_pt == TYPE_VOID) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "枚举变体 '%s::%s' 不接受 payload",
                  enum_name, variant);
        node->type.kind = TYPE_ERROR;
        return;
    }

    analyze_expr(payload, scope, errors, a);
    const TypeKind actual = payload->type.kind;
    if (actual != TYPE_ERROR && actual != expected_pt) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "枚举变体 payload 类型不匹配: 期望 '%s',得到 '%s'",
                  type_name(expected_pt), type_name(actual));
    }
    node->type.kind = TYPE_ENUM;
}
|
||||
|
||||
/*
 * Type-check an index expression (`array[index]`).
 *
 * Both operands are analyzed. The base must be TYPE_ARRAY and the index must
 * be TYPE_I64; the node then takes the array's element type (and element
 * struct name). Operands that are already TYPE_ERROR propagate silently.
 */
static void analyze_index_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* base = node->as.index_expr.array;
    AstNode* subscript = node->as.index_expr.index;

    analyze_expr(base, scope, errors, a);
    analyze_expr(subscript, scope, errors, a);

    const TypeKind base_kind = base->type.kind;
    const TypeKind sub_kind = subscript->type.kind;

    /* The base is validated before the index, so only one problem is reported. */
    if (base_kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (base_kind != TYPE_ARRAY) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "类型 '%s' 不支持索引操作", type_name(base_kind));
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sub_kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sub_kind != TYPE_I64) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "数组索引必须是 i64 类型, 得到 '%s'", type_name(sub_kind));
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.struct_name = base->type.element_struct_name;
    node->type.kind = base->type.element_type;
}
|
||||
|
||||
/*
 * Type-check a method call (`receiver.method(args...)`).
 *
 * The receiver must be a struct value. The method is looked up as the symbol
 * `Struct$method`; if that is not a function, every visible scope is searched
 * for a function whose name starts with the struct name and ends with
 * `$method` (trait-method fallback). On success the node's method_name is
 * replaced by the resolved symbol name (code generation locates the LLVM
 * function through it), the arguments are reordered/checked against
 * parameters 1..n (parameter 0 is the receiver) and the node takes the
 * function's return type.
 *
 * Differences from the previous implementation:
 *  - a receiver that is already TYPE_ERROR no longer yields a second error;
 *  - a non-function symbol found by the direct lookup no longer stops the
 *    fallback search after the first scope;
 *  - the fallback prefers candidates where the struct name is followed by
 *    '$', so `Foo` does not resolve to a method of `FooBar` when a method of
 *    `Foo` exists. Names without that delimiter are still accepted as a last
 *    resort.
 *    NOTE(review): the exact mangling of trait methods is not visible here;
 *    if it is always `Struct$Trait$method` the last-resort match can go.
 */
static void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* receiver = node->as.method_call.receiver;
    analyze_expr(receiver, scope, errors, a);

    if (receiver->type.kind == TYPE_ERROR) {
        /* Already reported while analyzing the receiver. */
        node->type.kind = TYPE_ERROR;
        return;
    }

    const char* recv_struct = receiver->type.struct_name;
    if (receiver->type.kind != TYPE_STRUCT || !recv_struct) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "只有结构体类型支持方法调用");
        node->type.kind = TYPE_ERROR;
        return;
    }

    char mangled[256];
    snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct,
             node->as.method_call.method_name);
    Symbol* sym = scope_lookup(scope, mangled);
    if (sym && sym->kind != SYM_FUNCTION) {
        sym = NULL; /* a same-named non-function must not block the fallback */
    }

    if (!sym) {
        char suffix[256];
        snprintf(suffix, sizeof(suffix), "$%s", node->as.method_call.method_name);
        const size_t suf_len = strlen(suffix);
        const size_t recv_len = strlen(recv_struct);
        Symbol* loose = NULL; /* first prefix-only match, used as a last resort */

        for (const Scope* sc = scope; sc && !sym; sc = sc->parent) {
            for (Symbol* s = sc->head; s; s = s->next) {
                if (s->kind != SYM_FUNCTION) {
                    continue;
                }
                const size_t name_len = strlen(s->name);
                if (name_len <= suf_len + recv_len
                    || strncmp(s->name, recv_struct, recv_len) != 0
                    || strcmp(s->name + name_len - suf_len, suffix) != 0) {
                    continue;
                }
                if (s->name[recv_len] == '$') {
                    sym = s;
                    break;
                }
                if (!loose) {
                    loose = s;
                }
            }
        }
        if (!sym) {
            sym = loose;
        }
    }

    if (!sym) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 没有方法 '%s'", recv_struct,
                  node->as.method_call.method_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    /* Point the node at the symbol's real name so later stages can find it. */
    node->as.method_call.method_name = sym->name;

    if (node->as.method_call.arg_count + 1 != sym->param_count) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "方法 '%s' 需要 %zu 个参数,提供了 %zu 个",
                  node->as.method_call.method_name,
                  sym->param_count > 0 ? sym->param_count - 1 : 0,
                  node->as.method_call.arg_count);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (!reorder_named_args(node, sym, 1, errors, node->as.method_call.method_name)) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    for (size_t i = 0; i < node->as.method_call.arg_count; i++) {
        analyze_expr(node->as.method_call.args[i], scope, errors, a);
        /* Parameter 0 is the receiver, so argument i is checked against i + 1. */
        check_arg_type(node->as.method_call.args[i], sym->param_types[i + 1],
                       sym->param_struct_names ? sym->param_struct_names[i + 1] : NULL,
                       i, node, sym, errors, a);
    }

    node->type.kind = sym->return_type;
    if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) {
        node->type.struct_name = sym->return_struct_type_name;
    }
}
|
||||
|
||||
// === 表达式类型检查(调度器) ===
|
||||
static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||||
switch (node->kind) {
|
||||
case AST_LITERAL_EXPR: break;
|
||||
case AST_IDENT_EXPR: analyze_ident_expr(node, scope, errors, a); break;
|
||||
case AST_UNARY_EXPR: analyze_unary_expr(node, scope, errors, a); break;
|
||||
case AST_BINARY_EXPR: analyze_binary_expr(node, scope, errors, a); break;
|
||||
case AST_CALL_EXPR: analyze_call_expr(node, scope, errors, a); break;
|
||||
case AST_FIELD_ACCESS: analyze_field_access(node, scope, errors, a); break;
|
||||
case AST_STRUCT_INIT: analyze_struct_init(node, scope, errors, a); break;
|
||||
case AST_ENUM_VARIANT: analyze_enum_variant(node, scope, errors, a); break;
|
||||
case AST_INDEX_EXPR: analyze_index_expr(node, scope, errors, a); break;
|
||||
case AST_METHOD_CALL: analyze_method_call(node, scope, errors, a); break;
|
||||
case AST_IF_STMT: analyze_node(node, scope, errors, a); break;
|
||||
case AST_BLOCK: analyze_node(node, scope, errors, a); break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||||
void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||||
if (!node) return;
|
||||
|
||||
switch (node->kind) {
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
#ifndef SEMA_INTERNAL_H
#define SEMA_INTERNAL_H

/*
 * Declarations shared between the semantic-analysis translation units:
 * monomorphization state, type-relation helpers and the per-node analyzers.
 */

#include "sema.h"
#include "symbol.h"
#include "ast.h"
#include "error.h"
#include "arena.h"
#include "l_lang.h"
#include <stdbool.h> /* bool, used by the prototypes below */
#include <stddef.h>  /* size_t */
#include <stdio.h>
#include <string.h>

// === Generic monomorphization queue ===
/* Concrete function instances created while checking calls to generic
 * functions, and the number of entries currently in use. */
extern AstNode* mono_queue[256];
extern size_t mono_count;
/* NOTE(review): not referenced in the code visible alongside this header;
 * presumably the arena monomorphized nodes are allocated from — confirm. */
extern Arena* mono_arena;
/* The current AST_PROGRAM node; used to look up generic function templates. */
extern AstNode* g_program;

// === Type inference context ===
/* NOTE(review): presumably the declared return type (and struct name) of the
 * function currently being analyzed — confirm against sema.c. */
extern TypeKind current_return_type;
extern const char* current_return_struct_name;

// === Type relations ===
/* Result type of an arithmetic operation on operands of kinds a and b. */
TypeKind promote(TypeKind a, TypeKind b);
/* True for the kinds accepted by arithmetic operators. */
bool is_numeric(TypeKind t);
/* True when a value of kind `from` may be used where `to` is expected
 * without an explicit conversion. */
bool can_implicit_convert(TypeKind from, TypeKind to);
/* True when values of kinds a and b may be compared with each other. */
bool is_comparable(TypeKind a, TypeKind b);

// === Generic monomorphization ===
/* Replace every reference to the type parameter `tparam` in the subtree
 * rooted at `node` with the concrete type (`concrete`, `concrete_sname`). */
void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname);
/* Same substitution applied to a single TypeInfo. */
void subst_type_info(TypeInfo* ti, const char* tparam, TypeKind concrete, const char* concrete_sname);

// === Expression type checking ===
/* Each analyzer computes node->type, adding diagnostics to `errors`. */
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
void analyze_ident_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
void analyze_unary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
void analyze_binary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
void analyze_field_access(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
void analyze_struct_init(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
void analyze_enum_variant(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
void analyze_index_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);

/* Check one call argument against the expected parameter type, reporting a
 * mismatch against `call_node`. Returns true only when the parameter is one
 * of fn_sym's generic type parameters (the caller then monomorphizes);
 * returns false in every other case, whether or not an error was reported. */
bool check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname,
                    size_t idx, AstNode* call_node, Symbol* fn_sym,
                    ErrorList* errors, Arena* a);

/* Reorder a call's arguments in place so that named arguments line up with
 * the parameters of `sym`; `param_offset` is the index of the first
 * parameter that corresponds to an argument. Returns false after reporting
 * an error against `call_name`. */
bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset,
                        ErrorList* errors, const char* call_name);

#endif
|
||||
@@ -0,0 +1,629 @@
|
||||
#include "sema_internal.h"
|
||||
|
||||
// === 泛型单态化: AST 类型替换 ===
|
||||
// 将 AST 中所有匹配 type_param_name 的类型引用替换为 concrete_type
|
||||
void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname);
|
||||
|
||||
/*
 * Replace a reference to the type parameter `tparam` inside one TypeInfo.
 *
 * Type parameters are represented as TYPE_STRUCT entries whose struct name is
 * the parameter's name. Both the type itself and, for arrays, the element
 * type are rewritten to (`concrete`, `concrete_sname`). A NULL `ti` or
 * `tparam` is ignored.
 */
void subst_type_info(TypeInfo* ti, const char* tparam, TypeKind concrete, const char* concrete_sname) {
    if (!ti || !tparam) {
        return;
    }
    if (ti->kind == TYPE_STRUCT && ti->struct_name && strcmp(ti->struct_name, tparam) == 0) {
        ti->kind = concrete;
        ti->struct_name = concrete_sname;
        return;
    }
    /* Array of T: substitute the element type so indexing yields the
     * concrete type instead of the unresolved parameter. */
    if (ti->kind == TYPE_ARRAY && ti->element_type == TYPE_STRUCT
        && ti->element_struct_name && strcmp(ti->element_struct_name, tparam) == 0) {
        ti->element_type = concrete;
        ti->element_struct_name = concrete_sname;
    }
}
|
||||
|
||||
/*
 * Recursively replace the type parameter `tparam` with a concrete type in
 * the subtree rooted at `node` (used when instantiating a generic function).
 *
 * Rewritten in place: each visited node's own TypeInfo, function return
 * types, parameter types and `let` annotations. A type parameter is
 * recognised as TYPE_STRUCT whose struct name equals `tparam`. NULL nodes
 * are ignored, so optional children can be passed unconditionally.
 *
 * Method calls, index expressions and enum-variant payloads are traversed as
 * well, so annotations nested below them are not missed.
 * NOTE(review): the target of an assignment is not visited, matching the
 * previous behaviour — confirm it can never carry a type annotation.
 */
void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname) {
    if (!node) return;
    /* The node's own (possibly already inferred) type. */
    subst_type_info(&node->type, tparam, concrete, concrete_sname);

    switch (node->kind) {
    case AST_PROGRAM:
        for (size_t i = 0; i < node->as.program.fn_count; i++)
            subst_ast_types(node->as.program.functions[i], tparam, concrete, concrete_sname);
        break;
    case AST_FUNCTION:
        if (node->as.function.return_type == TYPE_STRUCT
            && node->as.function.return_struct_type_name
            && strcmp(node->as.function.return_struct_type_name, tparam) == 0) {
            node->as.function.return_type = concrete;
            node->as.function.return_struct_type_name = concrete_sname;
        }
        for (size_t i = 0; i < node->as.function.param_count; i++)
            subst_ast_types(node->as.function.params[i], tparam, concrete, concrete_sname);
        subst_ast_types(node->as.function.body, tparam, concrete, concrete_sname);
        break;
    case AST_PARAMETER:
        if (node->as.parameter.type == TYPE_STRUCT && node->as.parameter.struct_type_name
            && strcmp(node->as.parameter.struct_type_name, tparam) == 0) {
            node->as.parameter.type = concrete;
            node->as.parameter.struct_type_name = concrete_sname;
        }
        break;
    case AST_BLOCK:
        for (size_t i = 0; i < node->as.block.stmt_count; i++)
            subst_ast_types(node->as.block.stmts[i], tparam, concrete, concrete_sname);
        break;
    case AST_LET_STMT:
        if (node->as.let_stmt.annot_type == TYPE_STRUCT && node->as.let_stmt.struct_type_name
            && strcmp(node->as.let_stmt.struct_type_name, tparam) == 0) {
            node->as.let_stmt.annot_type = concrete;
            node->as.let_stmt.struct_type_name = concrete_sname;
        }
        subst_ast_types(node->as.let_stmt.init, tparam, concrete, concrete_sname);
        break;
    case AST_IF_STMT:
        subst_ast_types(node->as.if_stmt.cond, tparam, concrete, concrete_sname);
        subst_ast_types(node->as.if_stmt.then_block, tparam, concrete, concrete_sname);
        subst_ast_types(node->as.if_stmt.else_block, tparam, concrete, concrete_sname);
        break;
    case AST_WHILE_STMT:
        subst_ast_types(node->as.while_stmt.cond, tparam, concrete, concrete_sname);
        subst_ast_types(node->as.while_stmt.body, tparam, concrete, concrete_sname);
        break;
    case AST_RETURN_STMT:
        subst_ast_types(node->as.return_stmt.expr, tparam, concrete, concrete_sname);
        break;
    case AST_EXPR_STMT:
        subst_ast_types(node->as.expr_stmt.expr, tparam, concrete, concrete_sname);
        break;
    case AST_BINARY_EXPR:
        subst_ast_types(node->as.binary.left, tparam, concrete, concrete_sname);
        subst_ast_types(node->as.binary.right, tparam, concrete, concrete_sname);
        break;
    case AST_UNARY_EXPR:
        subst_ast_types(node->as.unary.operand, tparam, concrete, concrete_sname);
        break;
    case AST_CALL_EXPR:
        for (size_t i = 0; i < node->as.call.arg_count; i++)
            subst_ast_types(node->as.call.args[i], tparam, concrete, concrete_sname);
        break;
    case AST_METHOD_CALL:
        subst_ast_types(node->as.method_call.receiver, tparam, concrete, concrete_sname);
        for (size_t i = 0; i < node->as.method_call.arg_count; i++)
            subst_ast_types(node->as.method_call.args[i], tparam, concrete, concrete_sname);
        break;
    case AST_INDEX_EXPR:
        subst_ast_types(node->as.index_expr.array, tparam, concrete, concrete_sname);
        subst_ast_types(node->as.index_expr.index, tparam, concrete, concrete_sname);
        break;
    case AST_ENUM_VARIANT:
        subst_ast_types(node->as.enum_variant.payload, tparam, concrete, concrete_sname);
        break;
    case AST_ASSIGN_STMT:
        subst_ast_types(node->as.assign_stmt.value, tparam, concrete, concrete_sname);
        break;
    case AST_FIELD_ACCESS:
        subst_ast_types(node->as.field_access.object, tparam, concrete, concrete_sname);
        break;
    case AST_STRUCT_INIT:
        for (size_t i = 0; i < node->as.struct_init.field_count; i++)
            subst_ast_types(node->as.struct_init.field_values[i], tparam, concrete, concrete_sname);
        break;
    default:
        break;
    }
}
|
||||
|
||||
// === 类型关系 ===
|
||||
TypeKind current_return_type = TYPE_VOID;
|
||||
const char* current_return_struct_name = NULL;
|
||||
|
||||
// 单态化队列: 泛型函数调用时生成的具象化函数
|
||||
AstNode* mono_queue[256];
|
||||
size_t mono_count = 0;
|
||||
Arena* mono_arena = NULL;
|
||||
AstNode* g_program = NULL; // 当前 AST_PROGRAM(用于查找泛型函数模板)
|
||||
|
||||
/*
 * Result type of an arithmetic operation on operands of kinds `a` and `b`.
 *
 * Enums take part in arithmetic as i64 and chars as i32. After that
 * normalization the result is the first kind, in the order
 * f64, i64, u64, i32, bool, that either operand has; TYPE_ERROR when neither
 * operand has any of them.
 */
TypeKind promote(TypeKind a, TypeKind b) {
    static const TypeKind by_priority[] = {
        TYPE_F64, TYPE_I64, TYPE_U64, TYPE_I32, TYPE_BOOL,
    };

    if (a == TYPE_ENUM) {
        a = TYPE_I64;
    } else if (a == TYPE_CHAR) {
        a = TYPE_I32;
    }
    if (b == TYPE_ENUM) {
        b = TYPE_I64;
    } else if (b == TYPE_CHAR) {
        b = TYPE_I32;
    }

    for (size_t i = 0; i < sizeof by_priority / sizeof by_priority[0]; i++) {
        const TypeKind candidate = by_priority[i];
        if (a == candidate || b == candidate) {
            return candidate;
        }
    }
    return TYPE_ERROR;
}
|
||||
|
||||
/* True for the kinds accepted by arithmetic operators: the integer and
 * floating-point kinds, plus char and enum. */
bool is_numeric(TypeKind t) {
    switch (t) {
    case TYPE_I32:
    case TYPE_I64:
    case TYPE_U64:
    case TYPE_F64:
    case TYPE_CHAR:
    case TYPE_ENUM:
        return true;
    default:
        return false;
    }
}
|
||||
// 隐式类型转换规则: 无损加宽允许,有符号→无符号不允许
|
||||
bool can_implicit_convert(TypeKind from, TypeKind to) {
|
||||
if (from == to) return true;
|
||||
// 枚举视为 i64
|
||||
if (from == TYPE_ENUM) from = TYPE_I64;
|
||||
if (to == TYPE_ENUM) to = TYPE_I64;
|
||||
// char 可转为任意整数
|
||||
if (from == TYPE_CHAR) return to == TYPE_I32 || to == TYPE_I64 || to == TYPE_U64 || to == TYPE_F64;
|
||||
// i32 可加宽
|
||||
if (from == TYPE_I32) return to == TYPE_I64 || to == TYPE_F64;
|
||||
// i64 可转 f64
|
||||
if (from == TYPE_I64) return to == TYPE_F64;
|
||||
// u64 ↔ i64 双向允许(同一位宽,LLVM 同类型)
|
||||
if (from == TYPE_U64) return to == TYPE_F64 || to == TYPE_I64;
|
||||
if (from == TYPE_I64) return to == TYPE_F64 || to == TYPE_U64;
|
||||
return false;
|
||||
}
|
||||
/* True when values of kinds `a` and `b` may be compared: identical kinds, or
 * an enum against an i64 (enums can take part in integer comparisons). */
bool is_comparable(TypeKind a, TypeKind b) {
    if (a == TYPE_ENUM && b == TYPE_I64) {
        return true;
    }
    if (a == TYPE_I64 && b == TYPE_ENUM) {
        return true;
    }
    return a == b;
}
|
||||
|
||||
// === 向前声明 ===
|
||||
void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
|
||||
|
||||
// === 表达式类型检查辅助函数 ===
|
||||
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
|
||||
|
||||
/*
 * Type-check an identifier used as an expression.
 *
 * The name must resolve to a symbol that is neither a type alias nor a
 * function; otherwise an error is reported and the node becomes TYPE_ERROR.
 * On success the node copies the symbol's type, including the struct name
 * for structs and the element type, element struct name and size for arrays.
 */
void analyze_ident_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    (void)a;
    const char* name = node->as.ident.name;
    Symbol* sym = scope_lookup(scope, name);

    if (sym == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的变量 '%s'", name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sym->is_type_alias) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'%s' 是类型别名,不能作为表达式使用", name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sym->kind == SYM_FUNCTION) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'%s' 是函数,不能作为表达式使用", name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.kind = sym->type;
    switch (sym->type) {
    case TYPE_STRUCT:
        if (sym->struct_type_name) {
            node->type.struct_name = sym->struct_type_name;
        }
        break;
    case TYPE_ARRAY:
        node->type.element_type = sym->array_element_type;
        node->type.element_struct_name = sym->array_element_struct_name;
        node->type.array_size = sym->array_size;
        break;
    default:
        break;
    }
}
|
||||
|
||||
/*
 * Type-check a unary expression.
 *
 * OP_NEG requires a numeric operand and keeps the operand's type; any other
 * operator is treated as logical NOT, which requires a bool operand and
 * yields bool. An operand that is already TYPE_ERROR propagates without a
 * second diagnostic, matching analyze_binary_expr.
 */
void analyze_unary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* operand = node->as.unary.operand;
    analyze_expr(operand, scope, errors, a);

    const TypeKind inner = operand->type.kind;
    if (inner == TYPE_ERROR) {
        /* The operand already reported its own error. */
        node->type.kind = TYPE_ERROR;
        return;
    }

    if (node->as.unary.op == OP_NEG) {
        if (!is_numeric(inner)) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "一元 '-' 只能用于数值类型");
            node->type.kind = TYPE_ERROR;
            return;
        }
        node->type.kind = inner;
        return;
    }

    /* OP_NOT */
    if (inner != TYPE_BOOL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'!' 只能用于布尔类型,得到 '%s'", type_name(inner));
        node->type.kind = TYPE_ERROR;
        return;
    }
    node->type.kind = TYPE_BOOL;
}
|
||||
|
||||
/*
 * Type-check a binary expression.
 *
 * Both operands are analyzed first; if either is TYPE_ERROR the node becomes
 * TYPE_ERROR without a new diagnostic. Otherwise:
 *   - `+` concatenates when both sides are str (str mixed with anything else
 *     is an error) and is ordinary arithmetic for everything else;
 *   - `- * / %` need numeric operands and yield promote(l, r);
 *   - comparisons need comparable operands and yield bool;
 *   - `&&` / `||` need bool operands and yield bool.
 * Any other operator leaves the node's type untouched.
 */
void analyze_binary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* lhs = node->as.binary.left;
    AstNode* rhs = node->as.binary.right;
    analyze_expr(lhs, scope, errors, a);
    analyze_expr(rhs, scope, errors, a);

    const TypeKind l = lhs->type.kind;
    const TypeKind r = rhs->type.kind;
    if (l == TYPE_ERROR || r == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    switch (node->as.binary.op) {
    case OP_ADD:
        if (l == TYPE_STR && r == TYPE_STR) {
            node->type.kind = TYPE_STR;
            break;
        }
        if (l == TYPE_STR || r == TYPE_STR) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'",
                      type_name(l), type_name(r));
            node->type.kind = TYPE_ERROR;
            break;
        }
        /* No strings involved: `+` is plain arithmetic. */
        /* fallthrough */
    case OP_SUB:
    case OP_MUL:
    case OP_DIV:
    case OP_MOD:
        if (is_numeric(l) && is_numeric(r)) {
            node->type.kind = promote(l, r);
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col, "算术运算需要数值类型");
            node->type.kind = TYPE_ERROR;
        }
        break;
    case OP_EQ:
    case OP_NE:
    case OP_LT:
    case OP_GT:
    case OP_LE:
    case OP_GE:
        if (is_comparable(l, r)) {
            node->type.kind = TYPE_BOOL;
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "类型 '%s' 和 '%s' 无法比较", type_name(l), type_name(r));
            node->type.kind = TYPE_ERROR;
        }
        break;
    case OP_AND:
    case OP_OR:
        if (l == TYPE_BOOL && r == TYPE_BOOL) {
            node->type.kind = TYPE_BOOL;
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col, "逻辑运算需要布尔类型");
            node->type.kind = TYPE_ERROR;
        }
        break;
    default:
        break;
    }
}
|
||||
|
||||
// 参数类型匹配检查(CALL_EXPR 和 METHOD_CALL 共用)
|
||||
bool check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname,
|
||||
size_t idx, AstNode* call_node, Symbol* fn_sym,
|
||||
ErrorList* errors, Arena* a) {
|
||||
(void)a;
|
||||
TypeKind actual = arg->type.kind;
|
||||
if (actual == TYPE_ERROR) return false;
|
||||
if (expected == TYPE_STRUCT && expected_sname) {
|
||||
// 检查是否是泛型类型参数(匹配则接受任意类型)
|
||||
if (fn_sym && fn_sym->type_params) {
|
||||
for (size_t t = 0; t < fn_sym->type_param_count; t++) {
|
||||
if (strcmp(expected_sname, fn_sym->type_params[t]) == 0)
|
||||
return true; // 泛型参数,接受任意类型
|
||||
}
|
||||
}
|
||||
const char* actual_name = arg->type.struct_name;
|
||||
if (actual != TYPE_STRUCT || !actual_name ||
|
||||
strcmp(actual_name, expected_sname) != 0) {
|
||||
error_add(errors, "<sema>", call_node->loc.line, call_node->loc.col,
|
||||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||||
idx + 1, expected_sname ? expected_sname : "struct",
|
||||
actual_name ? actual_name : type_name(actual));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if (actual == expected) return false;
|
||||
if (expected == TYPE_I64 && actual == TYPE_ENUM) return false;
|
||||
if (can_implicit_convert(actual, expected)) return false;
|
||||
if (actual == TYPE_I64 && arg->kind == AST_LITERAL_EXPR
|
||||
&& (expected == TYPE_I32 || expected == TYPE_U64 || expected == TYPE_CHAR)) return false;
|
||||
error_add(errors, "<sema>", call_node->loc.line, call_node->loc.col,
|
||||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||||
idx + 1, type_name(expected), type_name(actual));
|
||||
return false;
|
||||
}
|
||||
|
||||
// 命名参数重排序(CALL_EXPR 和 METHOD_CALL 共用)
|
||||
bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset,
|
||||
ErrorList* errors, const char* call_name) {
|
||||
AstNode** args = node->as.call.args;
|
||||
const char** arg_names = node->as.call.arg_names;
|
||||
size_t arg_count = node->as.call.arg_count;
|
||||
if (!arg_names) return true;
|
||||
AstNode* reordered[16] = {0};
|
||||
for (size_t i = 0; i < arg_count; i++) {
|
||||
if (arg_names[i]) {
|
||||
bool found = false;
|
||||
for (size_t j = param_offset; j < sym->param_count; j++) {
|
||||
if (sym->param_names && sym->param_names[j] &&
|
||||
strcmp(arg_names[i], sym->param_names[j]) == 0) {
|
||||
reordered[j - param_offset] = args[i];
|
||||
found = true; break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||||
"'%s' 没有名为 '%s' 的参数", call_name, arg_names[i]);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
reordered[i] = args[i];
|
||||
}
|
||||
}
|
||||
memcpy(args, reordered, arg_count * sizeof(AstNode*));
|
||||
return true;
|
||||
}
|
||||
|
||||
void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||||
Symbol* sym = scope_lookup(scope, node->as.call.name);
|
||||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||||
"未定义的函数 '%s'", node->as.call.name);
|
||||
node->type.kind = TYPE_ERROR;
|
||||
for (size_t i = 0; i < node->as.call.arg_count; i++)
|
||||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||||
return;
|
||||
}
|
||||
if (node->as.call.arg_count != sym->param_count) {
|
||||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||||
"函数 '%s' 需要 %zu 个参数,但提供了 %zu 个",
|
||||
node->as.call.name, sym->param_count, node->as.call.arg_count);
|
||||
node->type.kind = TYPE_ERROR;
|
||||
for (size_t i = 0; i < node->as.call.arg_count; i++)
|
||||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||||
return;
|
||||
}
|
||||
if (!reorder_named_args(node, sym, 0, errors, node->as.call.name)) {
|
||||
node->type.kind = TYPE_ERROR; return;
|
||||
}
|
||||
for (size_t i = 0; i < node->as.call.arg_count; i++) {
|
||||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||||
bool is_generic_param = check_arg_type(node->as.call.args[i], sym->param_types[i],
|
||||
sym->param_struct_names ? sym->param_struct_names[i] : NULL,
|
||||
i, node, sym, errors, a);
|
||||
// 泛型单态化: 创建具象化函数副本并注册
|
||||
if (is_generic_param && sym->type_params && sym->type_param_count > 0) {
|
||||
TypeKind concrete = node->as.call.args[i]->type.kind;
|
||||
const char* concrete_sn = node->as.call.args[i]->type.struct_name;
|
||||
// 构造 mangled 名: fn$concrete_type
|
||||
const char* ct_name = concrete_sn ? concrete_sn : type_name(concrete);
|
||||
int mname_len = snprintf(NULL, 0, "%s$%s", node->as.call.name, ct_name) + 1;
|
||||
char* mname = arena_alloc_impl(a, mname_len);
|
||||
snprintf(mname, mname_len, "%s$%s", node->as.call.name, ct_name);
|
||||
// 检查是否已存在
|
||||
Symbol* existing = scope_lookup(scope, mname);
|
||||
if (!existing && g_program) {
|
||||
// 查找原始泛型函数 AST 节点
|
||||
AstNode* generic_fn = NULL;
|
||||
for (size_t fn_i = 0; fn_i < g_program->as.program.fn_count; fn_i++) {
|
||||
if (strcmp(g_program->as.program.functions[fn_i]->as.function.name,
|
||||
node->as.call.name) == 0) {
|
||||
generic_fn = g_program->as.program.functions[fn_i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (generic_fn && mono_count < 256) {
|
||||
// 创建浅拷贝(共享 body,subst_ast_types 修改类型标注)
|
||||
AstNode* mono_fn = ast_make_function(a, mname,
|
||||
generic_fn->as.function.params,
|
||||
generic_fn->as.function.param_count,
|
||||
generic_fn->as.function.return_type,
|
||||
generic_fn->as.function.return_struct_type_name,
|
||||
generic_fn->as.function.body,
|
||||
false, NULL, 0,
|
||||
generic_fn->loc);
|
||||
// 类型替换: T → concrete
|
||||
subst_ast_types(mono_fn, sym->type_params[0], concrete, concrete_sn);
|
||||
// 注册到队列
|
||||
mono_queue[mono_count++] = mono_fn;
|
||||
// 注册符号(后续分析会处理函数体)
|
||||
TypeKind* mpts = mono_fn->as.function.param_count > 0
|
||||
? arena_alloc_impl(a, mono_fn->as.function.param_count * sizeof(TypeKind)) : NULL;
|
||||
for (size_t pj = 0; pj < mono_fn->as.function.param_count; pj++) {
|
||||
mpts[pj] = mono_fn->as.function.params[pj]->as.parameter.type;
|
||||
}
|
||||
scope_insert_function(scope, a, mname,
|
||||
mono_fn->as.function.return_type,
|
||||
mono_fn->as.function.return_struct_type_name,
|
||||
mpts, NULL, NULL,
|
||||
mono_fn->as.function.param_count, NULL, 0);
|
||||
}
|
||||
}
|
||||
// 重定向调用到单态化函数
|
||||
node->as.call.name = mname;
|
||||
sym = scope_lookup(scope, mname);
|
||||
if (!sym) { node->type.kind = TYPE_ERROR; return; }
|
||||
}
|
||||
}
|
||||
node->type.kind = sym->return_type;
|
||||
if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name)
|
||||
node->type.struct_name = sym->return_struct_type_name;
|
||||
}
|
||||
|
||||
/*
 * Type-check a field access expression (`object.field`).
 *
 * The object expression is analyzed first. When it is a struct whose symbol
 * can be found in scope, the node receives the field's type and the field's
 * index inside that struct. Every failure marks the node TYPE_ERROR; a
 * diagnostic is added unless the object itself was already TYPE_ERROR.
 */
void analyze_field_access(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* target = node->as.field_access.object;
    analyze_expr(target, scope, errors, a);

    const TypeKind target_kind = target->type.kind;
    if (target_kind == TYPE_ERROR) {
        /* Already reported while analyzing the object. */
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (target_kind != TYPE_STRUCT) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "类型 '%s' 不是结构体,不能访问字段 '%s'",
                  type_name(target_kind), node->as.field_access.field);
        node->type.kind = TYPE_ERROR;
        return;
    }

    const char* owner_name = target->type.struct_name;
    if (owner_name == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col, "无法确定结构体类型");
        node->type.kind = TYPE_ERROR;
        return;
    }

    Symbol* owner = scope_lookup_struct(scope, owner_name);
    if (owner == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的结构体 '%s'", owner_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    const int slot = scope_struct_field_index(owner, node->as.field_access.field);
    if (slot < 0) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 没有字段 '%s'", owner_name, node->as.field_access.field);
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->as.field_access.field_index = slot;
    node->type.kind = owner->struct_field_types[slot];
    if (node->type.kind != TYPE_STRUCT) {
        return;
    }
    /* Nested struct field: carry the struct's name along when it is recorded. */
    if (owner->struct_field_struct_names && owner->struct_field_struct_names[slot]) {
        node->type.struct_name = owner->struct_field_struct_names[slot];
    }
}
|
||||
|
||||
// Type-check a struct initializer expression (`Type { field: value, ... }`).
//
// Steps:
//   1. Resolve the type name to a struct symbol. If no struct has that name,
//      the name is looked up as an ordinary symbol; when that symbol is a type
//      alias carrying a struct type name, the alias target is used and the
//      node's type_name is rewritten to it.
//   2. Require the number of initializers to equal the struct's field count.
//   3. Analyze every field value, then check that the field exists, that it
//      has not already been initialized earlier in the same expression, and
//      that the value's type kind equals the declared field type kind.
//
// On success the node becomes TYPE_STRUCT with struct_name set to the resolved
// name. An unknown type, a count mismatch, an unknown field or a duplicated
// field marks the node TYPE_ERROR. A field type mismatch is reported but does
// not by itself change the node's type.
void analyze_struct_init(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    const char* resolved = node->as.struct_init.type_name;
    Symbol* struct_sym = scope_lookup_struct(scope, resolved);
    if (!struct_sym) {
        // Not a struct name: try to resolve it as a type alias of a struct.
        Symbol* alias_sym = scope_lookup(scope, resolved);
        if (alias_sym && alias_sym->is_type_alias && alias_sym->struct_type_name) {
            resolved = alias_sym->struct_type_name;
            struct_sym = scope_lookup_struct(scope, resolved);
            node->as.struct_init.type_name = resolved;
        }
    }
    if (!struct_sym) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的结构体类型 '%s'", node->as.struct_init.type_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    if (node->as.struct_init.field_count != struct_sym->struct_field_count) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 有 %zu 个字段,但提供了 %zu 个",
                  node->as.struct_init.type_name,
                  struct_sym->struct_field_count, node->as.struct_init.field_count);
        node->type.kind = TYPE_ERROR;
        return;
    }

    for (size_t i = 0; i < node->as.struct_init.field_count; i++) {
        const char* fname = node->as.struct_init.field_names[i];
        AstNode* fval = node->as.struct_init.field_values[i];

        // The value is analyzed even when the field turns out to be invalid,
        // so errors inside the value expression are still reported.
        analyze_expr(fval, scope, errors, a);

        int fi = scope_struct_field_index(struct_sym, fname);
        if (fi < 0) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "结构体 '%s' 没有字段 '%s'", node->as.struct_init.type_name, fname);
            node->type.kind = TYPE_ERROR;
            continue;
        }

        // Reject a field that was already initialized earlier in this
        // expression. Because the initializer count equals the field count,
        // a repeated name necessarily means another field is left out, which
        // the count check alone cannot detect.
        int repeated = 0;
        for (size_t j = 0; j < i && !repeated; j++) {
            const char* earlier = node->as.struct_init.field_names[j];
            if (earlier && fname && strcmp(earlier, fname) == 0)
                repeated = 1;
        }
        if (repeated) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "结构体 '%s' 的字段 '%s' 被重复初始化",
                      node->as.struct_init.type_name, fname);
            node->type.kind = TYPE_ERROR;
            continue;
        }

        TypeKind expected = struct_sym->struct_field_types[fi];
        TypeKind actual = fval->type.kind;
        // A value that already failed analysis is not reported a second time.
        if (actual != TYPE_ERROR && actual != expected) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'",
                      fname, type_name(expected), type_name(actual));
        }
    }

    // Only stamp the struct type if no field marked the node as erroneous.
    if (node->type.kind != TYPE_ERROR) {
        node->type.kind = TYPE_STRUCT;
        node->type.struct_name = resolved;
    }
}
|
||||
|
||||
// Type-check an enum variant expression (`Enum::Variant` with an optional
// payload expression).
//
// The enum name must resolve (via scope_lookup_struct) to a symbol of kind
// SYM_ENUM and the variant must exist in it; the variant's index is stored on
// the node. When a payload expression is present:
//   - if the enum records payload types and this variant's is TYPE_VOID, the
//     payload is rejected and the node becomes TYPE_ERROR;
//   - otherwise the payload is analyzed and its type kind is compared with the
//     expected one; a mismatch is reported but the node still becomes
//     TYPE_ENUM.
//
// NOTE(review): a variant whose declared payload type is not TYPE_VOID is
// accepted without a payload here; confirm whether that is intentional.
void analyze_enum_variant(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    Symbol* enum_sym = scope_lookup_struct(scope, node->as.enum_variant.enum_name);
    if (enum_sym == NULL || enum_sym->kind != SYM_ENUM) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的枚举 '%s'", node->as.enum_variant.enum_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    int variant_slot = scope_enum_variant_index(enum_sym, node->as.enum_variant.variant_name);
    if (variant_slot < 0) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "枚举 '%s' 没有变体 '%s'",
                  node->as.enum_variant.enum_name, node->as.enum_variant.variant_name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    node->as.enum_variant.variant_index = variant_slot;

    // ADT: validate the payload against the variant's declared payload type.
    // Without a payload type table the expected type defaults to TYPE_VOID.
    TypeKind declared = enum_sym->variant_payload_types
                            ? enum_sym->variant_payload_types[variant_slot]
                            : TYPE_VOID;

    AstNode* payload = node->as.enum_variant.payload;
    if (payload != NULL) {
        if (enum_sym->variant_payload_types && declared == TYPE_VOID) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "枚举变体 '%s::%s' 不接受 payload",
                      node->as.enum_variant.enum_name, node->as.enum_variant.variant_name);
            node->type.kind = TYPE_ERROR;
            return;
        }

        analyze_expr(payload, scope, errors, a);

        TypeKind supplied = payload->type.kind;
        // A payload that already failed analysis is not reported again.
        if (!(supplied == TYPE_ERROR || supplied == declared)) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "枚举变体 payload 类型不匹配: 期望 '%s',得到 '%s'",
                      type_name(declared), type_name(supplied));
        }
    }

    node->type.kind = TYPE_ENUM;
}
|
||||
|
||||
// Type-check an index expression (`array[index]`).
//
// Both operands are analyzed up front. The array operand must have kind
// TYPE_ARRAY and the index operand must have kind TYPE_I64. On success the
// node takes the array's element type and element struct name. On failure the
// node becomes TYPE_ERROR; a diagnostic is reported unless the offending
// operand is itself already TYPE_ERROR. The array operand is checked before
// the index operand, so at most one diagnostic is produced here.
void analyze_index_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* base = node->as.index_expr.array;
    AstNode* subscript = node->as.index_expr.index;

    analyze_expr(base, scope, errors, a);
    analyze_expr(subscript, scope, errors, a);

    if (base->type.kind != TYPE_ARRAY) {
        // Stay silent when the array operand already failed analysis.
        if (base->type.kind != TYPE_ERROR) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "类型 '%s' 不支持索引操作", type_name(base->type.kind));
        }
        node->type.kind = TYPE_ERROR;
        return;
    }

    if (subscript->type.kind != TYPE_I64) {
        // Stay silent when the index operand already failed analysis.
        if (subscript->type.kind != TYPE_ERROR) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "数组索引必须是 i64 类型, 得到 '%s'", type_name(subscript->type.kind));
        }
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.kind = base->type.element_type;
    node->type.struct_name = base->type.element_struct_name;
}
|
||||
|
||||
void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||||
analyze_expr(node->as.method_call.receiver, scope, errors, a);
|
||||
const char* recv_struct = node->as.method_call.receiver->type.struct_name;
|
||||
if (node->as.method_call.receiver->type.kind != TYPE_STRUCT || !recv_struct) {
|
||||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||||
"只有结构体类型支持方法调用");
|
||||
node->type.kind = TYPE_ERROR; return;
|
||||
}
|
||||
char mangled[256];
|
||||
snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct,
|
||||
node->as.method_call.method_name);
|
||||
Symbol* sym = scope_lookup(scope, mangled);
|
||||
// trait 方法 fallback: 搜索所有作用域中以 $method_name 结尾且以 StructName 开头的符号
|
||||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||||
char suffix[256];
|
||||
snprintf(suffix, sizeof(suffix), "$%s", node->as.method_call.method_name);
|
||||
size_t suf_len = strlen(suffix);
|
||||
size_t recv_len = strlen(recv_struct);
|
||||
for (const Scope* sc = scope; sc; sc = sc->parent) {
|
||||
for (Symbol* s = sc->head; s; s = s->next) {
|
||||
if (s->kind == SYM_FUNCTION) {
|
||||
size_t name_len = strlen(s->name);
|
||||
if (name_len > suf_len + recv_len
|
||||
&& strncmp(s->name, recv_struct, recv_len) == 0
|
||||
&& strcmp(s->name + name_len - suf_len, suffix) == 0) {
|
||||
sym = s;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (sym) break;
|
||||
}
|
||||
}
|
||||
// 更新 method_name 为符号的实际名称(codegen 需要通过它找到 LLVM 函数)
|
||||
if (sym && sym->kind == SYM_FUNCTION) {
|
||||
node->as.method_call.method_name = sym->name;
|
||||
}
|
||||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||||
"结构体 '%s' 没有方法 '%s'", recv_struct,
|
||||
node->as.method_call.method_name);
|
||||
node->type.kind = TYPE_ERROR; return;
|
||||
}
|
||||
if (node->as.method_call.arg_count + 1 != sym->param_count) {
|
||||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||||
"方法 '%s' 需要 %zu 个参数,提供了 %zu 个",
|
||||
node->as.method_call.method_name,
|
||||
sym->param_count > 0 ? sym->param_count - 1 : 0,
|
||||
node->as.method_call.arg_count);
|
||||
node->type.kind = TYPE_ERROR; return;
|
||||
}
|
||||
if (!reorder_named_args(node, sym, 1, errors, node->as.method_call.method_name)) {
|
||||
node->type.kind = TYPE_ERROR; return;
|
||||
}
|
||||
for (size_t i = 0; i < node->as.method_call.arg_count; i++) {
|
||||
analyze_expr(node->as.method_call.args[i], scope, errors, a);
|
||||
check_arg_type(node->as.method_call.args[i], sym->param_types[i + 1],
|
||||
sym->param_struct_names ? sym->param_struct_names[i + 1] : NULL,
|
||||
i, node, sym, errors, a);
|
||||
}
|
||||
node->type.kind = sym->return_type;
|
||||
if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name)
|
||||
node->type.struct_name = sym->return_struct_type_name;
|
||||
}
|
||||
|
||||
// === 表达式类型检查(调度器) ===
|
||||
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||||
switch (node->kind) {
|
||||
case AST_LITERAL_EXPR: break;
|
||||
case AST_IDENT_EXPR: analyze_ident_expr(node, scope, errors, a); break;
|
||||
case AST_UNARY_EXPR: analyze_unary_expr(node, scope, errors, a); break;
|
||||
case AST_BINARY_EXPR: analyze_binary_expr(node, scope, errors, a); break;
|
||||
case AST_CALL_EXPR: analyze_call_expr(node, scope, errors, a); break;
|
||||
case AST_FIELD_ACCESS: analyze_field_access(node, scope, errors, a); break;
|
||||
case AST_STRUCT_INIT: analyze_struct_init(node, scope, errors, a); break;
|
||||
case AST_ENUM_VARIANT: analyze_enum_variant(node, scope, errors, a); break;
|
||||
case AST_INDEX_EXPR: analyze_index_expr(node, scope, errors, a); break;
|
||||
case AST_METHOD_CALL: analyze_method_call(node, scope, errors, a); break;
|
||||
case AST_IF_STMT: analyze_node(node, scope, errors, a); break;
|
||||
case AST_BLOCK: analyze_node(node, scope, errors, a); break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user