#include "sema_internal.h"

// === Generic monomorphization: AST type substitution ===

// === Type relations ===

// Neither global is read or written in this part of the file.
// NOTE(review): presumably they track the return type of the function that is
// currently being analyzed -- confirm against the statement analyzer.
TypeKind current_return_type = TYPE_VOID;
const char* current_return_struct_name = NULL;

// === Type relations (data-driven via the TypeTable) ===

// Computes the result type of an arithmetic operation on `a` and `b`.
// Enums are treated as i64 and chars as i32. Returns TYPE_ERROR when either
// operand is not numeric, otherwise the operand type with the higher
// promote_rank (`a` wins ties).
TypeKind promote(TypeKind a, TypeKind b) {
    if (a == TYPE_ENUM) a = TYPE_I64;
    if (b == TYPE_ENUM) b = TYPE_I64;
    if (a == TYPE_CHAR) a = TYPE_I32;
    if (b == TYPE_CHAR) b = TYPE_I32;

    const TypeDesc* ta = type_lookup(a);
    const TypeDesc* tb = type_lookup(b);
    if (!ta->is_numeric || !tb->is_numeric) return TYPE_ERROR;
    return ta->promote_rank >= tb->promote_rank ? a : b;
}

// Returns whether the type table marks `t` as numeric.
bool is_numeric(TypeKind t) {
    return type_lookup(t)->is_numeric;
}

// Implicit conversion rules: identical types always convert; enums are
// treated as i64; numeric types of the same bit width (>= 32) convert in both
// directions (e.g. u64 <-> i64); widening is allowed from a signed source to
// any wider numeric type and from an unsigned source only to an unsigned one.
// NOTE(review): the same-bit-width rule only checks is_numeric, so if the type
// table contains non-integer numeric types of equal width they would convert
// implicitly as well -- confirm that this is intended.
bool can_implicit_convert(TypeKind from, TypeKind to) {
    if (from == to) return true;
    if (from == TYPE_ENUM) from = TYPE_I64;
    if (to == TYPE_ENUM) to = TYPE_I64;
    if (from == to) return true;

    const TypeDesc* tf = type_lookup(from);
    const TypeDesc* tt = type_lookup(to);
    if (!tf->is_numeric || !tt->is_numeric) return false;

    // Same bit width: signed/unsigned conversion allowed in both directions.
    if (tf->bit_width == tt->bit_width && tf->bit_width >= 32) return true;

    // Widening: signed -> anything, unsigned -> unsigned only.
    if (tt->promote_rank > tf->promote_rank) return tf->is_signed || !tt->is_signed;
    return false;
}

// Two types are comparable when they are identical, or when both are numeric
// after mapping enums to i64.
bool is_comparable(TypeKind a, TypeKind b) {
    if (a == b) return true;
    if (a == TYPE_ENUM) a = TYPE_I64;
    if (b == TYPE_ENUM) b = TYPE_I64;
    return type_lookup(a)->is_numeric && type_lookup(b)->is_numeric;
}

// === Forward declarations ===
void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);

// === Expression type-checking helpers ===
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);

// Resolves an identifier and copies the symbol's type information onto the
// node. Unknown names, type aliases and function names are reported as errors
// and typed TYPE_ERROR.
void analyze_ident_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    (void)a;
    Symbol* sym = scope_lookup(scope, node->as.ident.name);
    if (!sym) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "未定义的变量 '%s'", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sym->is_type_alias) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "'%s' 是类型别名,不能作为表达式使用", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sym->kind == SYM_FUNCTION) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "'%s' 是函数,不能作为表达式使用", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.kind = sym->type;
    if (sym->type == TYPE_STRUCT && sym->struct_type_name) {
        node->type.struct_name = sym->struct_type_name;
    }
    if (sym->type == TYPE_ARRAY) {
        node->type.element_type = sym->array_element_type;
        node->type.element_struct_name = sym->array_element_struct_name;
        node->type.array_size = sym->array_size;
    }
}

// Type-checks a unary expression: '-' requires a numeric operand and keeps its
// type; any other operator is treated as '!' and requires a bool operand.
void analyze_unary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    analyze_expr(node->as.unary.operand, scope, errors, a);
    TypeKind inner = node->as.unary.operand->type.kind;

    // The operand already produced a diagnostic; propagate the error type
    // without reporting a second, cascaded error (same policy as
    // analyze_binary_expr).
    if (inner == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    if (node->as.unary.op == OP_NEG) {
        if (!is_numeric(inner)) {
            error_add(errors, "", node->loc.line, node->loc.col,
                      "一元 '-' 只能用于数值类型");
            node->type.kind = TYPE_ERROR;
        } else {
            node->type.kind = inner;
        }
        return;
    }

    // OP_NOT
    if (inner != TYPE_BOOL) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "'!' 只能用于布尔类型,得到 '%s'", type_name(inner));
        node->type.kind = TYPE_ERROR;
    } else {
        node->type.kind = TYPE_BOOL;
    }
}

// Result type of an arithmetic operator: reports an error and yields
// TYPE_ERROR unless both operands are numeric, otherwise promote(l, r).
static TypeKind arith_result_type(AstNode* node, TypeKind l, TypeKind r, ErrorList* errors) {
    if (!is_numeric(l) || !is_numeric(r)) {
        error_add(errors, "", node->loc.line, node->loc.col, "算术运算需要数值类型");
        return TYPE_ERROR;
    }
    return promote(l, r);
}

// Type-checks a binary expression. If either operand is already TYPE_ERROR the
// node becomes TYPE_ERROR without a further diagnostic. Operators not listed
// in the switch leave node->type untouched.
void analyze_binary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    analyze_expr(node->as.binary.left, scope, errors, a);
    analyze_expr(node->as.binary.right, scope, errors, a);

    TypeKind l = node->as.binary.left->type.kind;
    TypeKind r = node->as.binary.right->type.kind;
    if (l == TYPE_ERROR || r == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    switch (node->as.binary.op) {
    case OP_ADD:
        // '+' doubles as string concatenation, which needs str on both sides.
        if (l == TYPE_STR || r == TYPE_STR) {
            if (l != TYPE_STR || r != TYPE_STR) {
                error_add(errors, "", node->loc.line, node->loc.col,
                          "字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'",
                          type_name(l), type_name(r));
                node->type.kind = TYPE_ERROR;
            } else {
                node->type.kind = TYPE_STR;
            }
        } else {
            node->type.kind = arith_result_type(node, l, r, errors);
        }
        break;

    case OP_SUB:
    case OP_MUL:
    case OP_DIV:
    case OP_MOD:
        node->type.kind = arith_result_type(node, l, r, errors);
        break;

    case OP_EQ:
    case OP_NE:
    case OP_LT:
    case OP_GT:
    case OP_LE:
    case OP_GE:
        if (!is_comparable(l, r)) {
            error_add(errors, "", node->loc.line, node->loc.col,
                      "类型 '%s' 和 '%s' 无法比较", type_name(l), type_name(r));
            node->type.kind = TYPE_ERROR;
        } else {
            node->type.kind = TYPE_BOOL;
        }
        break;

    case OP_AND:
    case OP_OR:
        if (l != TYPE_BOOL || r != TYPE_BOOL) {
            error_add(errors, "", node->loc.line, node->loc.col, "逻辑运算需要布尔类型");
            node->type.kind = TYPE_ERROR;
        } else {
            node->type.kind = TYPE_BOOL;
        }
        break;

    default:
        break;
    }
}

// Argument type check shared by CALL_EXPR and METHOD_CALL.
//
// Reports a mismatch between the already-analyzed argument `arg` and the
// expected parameter type; `idx` is the zero-based argument index used in the
// message and `call_node` supplies the source location.
//
// Returns true ONLY when the parameter is one of `fn_sym`'s generic type
// parameters (any argument type is accepted then). In every other case --
// match, accepted conversion, or reported mismatch -- it returns false, so the
// return value is "is a generic parameter", not "types match".
bool check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname, size_t idx,
                    AstNode* call_node, Symbol* fn_sym, ErrorList* errors, Arena* a) {
    (void)a;
    TypeKind actual = arg->type.kind;
    if (actual == TYPE_ERROR) return false;

    if (expected == TYPE_STRUCT && expected_sname) {
        // A struct name equal to a generic type parameter accepts any type.
        if (fn_sym && fn_sym->type_params) {
            for (size_t t = 0; t < fn_sym->type_param_count; t++) {
                if (strcmp(expected_sname, fn_sym->type_params[t]) == 0) return true;
            }
        }
        const char* actual_name = arg->type.struct_name;
        if (actual != TYPE_STRUCT || !actual_name || strcmp(actual_name, expected_sname) != 0) {
            error_add(errors, "", call_node->loc.line, call_node->loc.col,
                      "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", idx + 1,
                      expected_sname, actual_name ? actual_name : type_name(actual));
        }
        return false;
    }

    if (actual == expected) return false;
    if (expected == TYPE_I64 && actual == TYPE_ENUM) return false;
    if (can_implicit_convert(actual, expected)) return false;

    // An i64-typed literal may be passed where i32, u64 or char is expected.
    if (actual == TYPE_I64 && arg->kind == AST_LITERAL_EXPR &&
        (expected == TYPE_I32 || expected == TYPE_U64 || expected == TYPE_CHAR)) {
        return false;
    }

    error_add(errors, "", call_node->loc.line, call_node->loc.col,
              "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", idx + 1,
              type_name(expected), type_name(actual));
    return false;
}

// Capacity of the on-stack scratch array used by reorder_named_args.
enum { REORDER_MAX_ARGS = 16 };

// Named-argument reordering shared by CALL_EXPR and METHOD_CALL.
//
// Rewrites the node's argument array in place so that every named argument
// sits at the position of the parameter it names; unnamed arguments keep their
// own index. `param_offset` is the number of leading parameters of `sym` that
// have no matching entry in the argument list (1 for a method receiver).
// Returns false after reporting an error when a name does not match any
// parameter, when two arguments target the same parameter, or when there are
// more arguments than the scratch array can hold; `args` is left untouched in
// that case. Returns true without doing anything when the call has no
// argument names.
//
// NOTE(review): the arguments are read through node->as.call even for
// method-call nodes; this relies on the two union members having a compatible
// layout for args/arg_names/arg_count -- TODO confirm against the AST header.
bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset, ErrorList* errors,
                        const char* call_name) {
    AstNode** args = node->as.call.args;
    const char** arg_names = node->as.call.arg_names;
    size_t arg_count = node->as.call.arg_count;
    if (!arg_names) return true;

    // Without this guard a call with more than REORDER_MAX_ARGS arguments
    // would write past the end of `reordered`.
    if (arg_count > (size_t)REORDER_MAX_ARGS) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "'%s' 的参数过多,命名参数调用最多支持 %d 个参数",
                  call_name, (int)REORDER_MAX_ARGS);
        return false;
    }

    AstNode* reordered[REORDER_MAX_ARGS] = {0};
    size_t offset = (size_t)param_offset;

    for (size_t i = 0; i < arg_count; i++) {
        if (!arg_names[i]) {
            // Positional argument: stays at its own index.
            if (reordered[i]) {
                error_add(errors, "", node->loc.line, node->loc.col,
                          "'%s' 的第 %zu 个位置参数与命名参数冲突", call_name, i + 1);
                return false;
            }
            reordered[i] = args[i];
            continue;
        }

        bool found = false;
        for (size_t j = offset; j < sym->param_count && j - offset < arg_count; j++) {
            if (sym->param_names && sym->param_names[j] &&
                strcmp(arg_names[i], sym->param_names[j]) == 0) {
                size_t slot = j - offset;
                // A filled slot means the same parameter was supplied twice;
                // overwriting it would leave another slot NULL.
                if (reordered[slot]) {
                    error_add(errors, "", node->loc.line, node->loc.col,
                              "'%s' 的参数 '%s' 被重复指定", call_name, arg_names[i]);
                    return false;
                }
                reordered[slot] = args[i];
                found = true;
                break;
            }
        }
        if (!found) {
            error_add(errors, "", node->loc.line, node->loc.col,
                      "'%s' 没有名为 '%s' 的参数", call_name, arg_names[i]);
            return false;
        }
    }

    // arg_count arguments were placed into arg_count slots without collision,
    // so every slot is filled.
    memcpy(args, reordered, arg_count * sizeof(AstNode*));
    return true;
}

// Finds the top-level function named `name` in g_program, or NULL when there
// is no program or no such function.
static AstNode* find_program_function(const char* name) {
    if (!g_program) return NULL;
    for (size_t i = 0; i < g_program->as.program.fn_count; i++) {
        AstNode* fn = g_program->as.program.functions[i];
        if (strcmp(fn->as.function.name, name) == 0) return fn;
    }
    return NULL;
}

// Redirects the call `node` to the monomorphized instance of its generic
// callee `generic_sym` for the concrete type of `arg`, creating and
// registering that instance ("<fn>$<type>") on first use. Always rewrites
// node->as.call.name to the mangled name; returns the instance's symbol, or
// NULL when it does not exist and could not be created.
//
// NOTE(review): the instance shares the generic function's params and body
// (shallow copy) and only the first type parameter is substituted; whether
// subst_ast_types mutates the shared nodes is not visible here -- confirm
// before instantiating one generic function with several different types.
static Symbol* instantiate_generic_call(AstNode* node, Scope* scope, Symbol* generic_sym,
                                        AstNode* arg, Arena* a) {
    TypeKind concrete = arg->type.kind;
    const char* concrete_sn = arg->type.struct_name;
    const char* generic_name = node->as.call.name;

    // Mangled name: fn$concrete_type
    const char* ct_name = concrete_sn ? concrete_sn : type_name(concrete);
    int mname_len = snprintf(NULL, 0, "%s$%s", generic_name, ct_name) + 1;
    char* mname = arena_alloc_impl(a, mname_len);
    snprintf(mname, mname_len, "%s$%s", generic_name, ct_name);

    if (!scope_lookup(scope, mname)) {
        AstNode* generic_fn = find_program_function(generic_name);
        // 256 is the limit the original code applies to mono_queue.
        if (generic_fn && mono_count < 256) {
            // Shallow copy: shares the body; subst_ast_types rewrites the
            // type annotations (T -> concrete).
            AstNode* mono_fn = ast_make_function(
                a, mname, generic_fn->as.function.params, generic_fn->as.function.param_count,
                generic_fn->as.function.return_type,
                generic_fn->as.function.return_struct_type_name, generic_fn->as.function.body,
                false, NULL, 0, generic_fn->loc);
            subst_ast_types(mono_fn, generic_sym->type_params[0], concrete, concrete_sn);

            // Queue the instance; its body is analyzed later.
            mono_queue[mono_count++] = mono_fn;

            // Register the instance's symbol.
            size_t pcount = mono_fn->as.function.param_count;
            TypeKind* mpts = pcount > 0 ? arena_alloc_impl(a, pcount * sizeof(TypeKind)) : NULL;
            for (size_t pj = 0; pj < pcount; pj++) {
                mpts[pj] = mono_fn->as.function.params[pj]->as.parameter.type;
            }
            scope_insert_function(scope, a, mname, mono_fn->as.function.return_type,
                                  mono_fn->as.function.return_struct_type_name, mpts, NULL, NULL,
                                  pcount, NULL, 0);
        }
    }

    // Redirect the call to the monomorphized function.
    node->as.call.name = mname;
    return scope_lookup(scope, mname);
}

// Type-checks a free-function call: resolves the callee, checks the argument
// count, applies named-argument reordering, checks each argument and, for
// arguments bound to a generic type parameter, redirects the call to a
// monomorphized instance. The node's type becomes the callee's return type.
void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    Symbol* sym = scope_lookup(scope, node->as.call.name);
    if (!sym || sym->kind != SYM_FUNCTION) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "未定义的函数 '%s'", node->as.call.name);
        node->type.kind = TYPE_ERROR;
        // Still analyze the arguments so their own errors are reported.
        for (size_t i = 0; i < node->as.call.arg_count; i++) {
            analyze_expr(node->as.call.args[i], scope, errors, a);
        }
        return;
    }

    if (node->as.call.arg_count != sym->param_count) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "函数 '%s' 需要 %zu 个参数,但提供了 %zu 个", node->as.call.name,
                  sym->param_count, node->as.call.arg_count);
        node->type.kind = TYPE_ERROR;
        for (size_t i = 0; i < node->as.call.arg_count; i++) {
            analyze_expr(node->as.call.args[i], scope, errors, a);
        }
        return;
    }

    if (!reorder_named_args(node, sym, 0, errors, node->as.call.name)) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    for (size_t i = 0; i < node->as.call.arg_count; i++) {
        AstNode* arg = node->as.call.args[i];
        analyze_expr(arg, scope, errors, a);
        bool is_generic_param = check_arg_type(
            arg, sym->param_types[i], sym->param_struct_names ? sym->param_struct_names[i] : NULL,
            i, node, sym, errors, a);

        // Generic monomorphization: create (or reuse) the concrete instance
        // and continue checking against its symbol.
        if (is_generic_param && sym->type_params && sym->type_param_count > 0) {
            const char* generic_name = node->as.call.name;
            sym = instantiate_generic_call(node, scope, sym, arg, a);
            if (!sym) {
                // Previously this path produced TYPE_ERROR with no
                // diagnostic, so the failure was invisible to the user.
                error_add(errors, "", node->loc.line, node->loc.col,
                          "无法实例化泛型函数 '%s'", generic_name);
                node->type.kind = TYPE_ERROR;
                return;
            }
        }
    }

    node->type.kind = sym->return_type;
    if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) {
        node->type.struct_name = sym->return_struct_type_name;
    }
}

// Type-checks `object.field`: the object must be a struct with a known,
// defined struct type that has the field. Records the field index on the node
// and gives the node the field's type.
void analyze_field_access(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    analyze_expr(node->as.field_access.object, scope, errors, a);
    AstNode* obj = node->as.field_access.object;

    if (obj->type.kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (obj->type.kind != TYPE_STRUCT) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "类型 '%s' 不是结构体,不能访问字段 '%s'", type_name(obj->type.kind),
                  node->as.field_access.field);
        node->type.kind = TYPE_ERROR;
        return;
    }

    const char* struct_name = obj->type.struct_name;
    if (!struct_name) {
        error_add(errors, "", node->loc.line, node->loc.col, "无法确定结构体类型");
        node->type.kind = TYPE_ERROR;
        return;
    }

    Symbol* struct_sym = scope_lookup_struct(scope, struct_name);
    if (!struct_sym) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "未定义的结构体 '%s'", struct_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    int fi = scope_struct_field_index(struct_sym, node->as.field_access.field);
    if (fi < 0) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "结构体 '%s' 没有字段 '%s'", struct_name, node->as.field_access.field);
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.kind = struct_sym->struct_field_types[fi];
    node->as.field_access.field_index = fi;
    if (node->type.kind == TYPE_STRUCT && struct_sym->struct_field_struct_names &&
        struct_sym->struct_field_struct_names[fi]) {
        node->type.struct_name = struct_sym->struct_field_struct_names[fi];
    }
}

// Type-checks a struct literal. The type name may be a type alias that
// resolves to a struct, in which case the node is rewritten to the resolved
// name. The field count must match and every initializer is analyzed and
// compared (exact TypeKind equality) with the declared field type.
void analyze_struct_init(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    const char* resolved = node->as.struct_init.type_name;
    Symbol* struct_sym = scope_lookup_struct(scope, resolved);
    if (!struct_sym) {
        Symbol* alias_sym = scope_lookup(scope, resolved);
        if (alias_sym && alias_sym->is_type_alias && alias_sym->struct_type_name) {
            resolved = alias_sym->struct_type_name;
            struct_sym = scope_lookup_struct(scope, resolved);
            node->as.struct_init.type_name = resolved;
        }
    }
    if (!struct_sym) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "未定义的结构体类型 '%s'", node->as.struct_init.type_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    if (node->as.struct_init.field_count != struct_sym->struct_field_count) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "结构体 '%s' 有 %zu 个字段,但提供了 %zu 个", node->as.struct_init.type_name,
                  struct_sym->struct_field_count, node->as.struct_init.field_count);
        node->type.kind = TYPE_ERROR;
        return;
    }

    for (size_t i = 0; i < node->as.struct_init.field_count; i++) {
        const char* fname = node->as.struct_init.field_names[i];
        AstNode* fval = node->as.struct_init.field_values[i];
        analyze_expr(fval, scope, errors, a);

        int fi = scope_struct_field_index(struct_sym, fname);
        if (fi < 0) {
            error_add(errors, "", node->loc.line, node->loc.col,
                      "结构体 '%s' 没有字段 '%s'", node->as.struct_init.type_name, fname);
            node->type.kind = TYPE_ERROR;
            continue;
        }

        TypeKind expected = struct_sym->struct_field_types[fi];
        TypeKind actual = fval->type.kind;
        if (actual != TYPE_ERROR && actual != expected) {
            error_add(errors, "", node->loc.line, node->loc.col,
                      "字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'", fname,
                      type_name(expected), type_name(actual));
        }
    }

    // An unknown field above marks the node TYPE_ERROR; keep that marking.
    if (node->type.kind != TYPE_ERROR) {
        node->type.kind = TYPE_STRUCT;
        node->type.struct_name = resolved;
    }
}

// Type-checks `Enum::Variant` (optionally with a payload): resolves the enum
// and the variant, records the variant index and checks a supplied payload
// against the variant's declared payload type. The node's type is TYPE_ENUM.
// NOTE(review): a variant that declares a payload but is written without one
// is not reported here -- confirm whether that is checked elsewhere.
void analyze_enum_variant(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    Symbol* enum_sym = scope_lookup_struct(scope, node->as.enum_variant.enum_name);
    if (!enum_sym || enum_sym->kind != SYM_ENUM) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "未定义的枚举 '%s'", node->as.enum_variant.enum_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    int vi = scope_enum_variant_index(enum_sym, node->as.enum_variant.variant_name);
    if (vi < 0) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "枚举 '%s' 没有变体 '%s'", node->as.enum_variant.enum_name,
                  node->as.enum_variant.variant_name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    node->as.enum_variant.variant_index = vi;

    // ADT: check the payload. TYPE_VOID stands for "no payload".
    TypeKind expected_pt = TYPE_VOID;
    if (enum_sym->variant_payload_types) {
        expected_pt = enum_sym->variant_payload_types[vi];
    }

    if (node->as.enum_variant.payload) {
        if (expected_pt == TYPE_VOID && enum_sym->variant_payload_types) {
            error_add(errors, "", node->loc.line, node->loc.col,
                      "枚举变体 '%s::%s' 不接受 payload", node->as.enum_variant.enum_name,
                      node->as.enum_variant.variant_name);
            node->type.kind = TYPE_ERROR;
            return;
        }
        analyze_expr(node->as.enum_variant.payload, scope, errors, a);
        TypeKind actual = node->as.enum_variant.payload->type.kind;
        if (actual != TYPE_ERROR && actual != expected_pt) {
            error_add(errors, "", node->loc.line, node->loc.col,
                      "枚举变体 payload 类型不匹配: 期望 '%s',得到 '%s'",
                      type_name(expected_pt), type_name(actual));
        }
    }

    node->type.kind = TYPE_ENUM;
}

// Type-checks `array[index]`: the indexed expression must be an array and the
// index must be exactly i64. The node takes the array's element type.
void analyze_index_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    analyze_expr(node->as.index_expr.array, scope, errors, a);
    analyze_expr(node->as.index_expr.index, scope, errors, a);
    AstNode* arr = node->as.index_expr.array;
    AstNode* idx = node->as.index_expr.index;

    if (arr->type.kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (arr->type.kind != TYPE_ARRAY) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "类型 '%s' 不支持索引操作", type_name(arr->type.kind));
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (idx->type.kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (idx->type.kind != TYPE_I64) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "数组索引必须是 i64 类型, 得到 '%s'", type_name(idx->type.kind));
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.kind = arr->type.element_type;
    node->type.struct_name = arr->type.element_struct_name;
}

// Type-checks `receiver.method(args)`. The receiver must be a struct. The
// method is looked up as "<Struct>$<method>"; if that is not a function, every
// enclosing scope is searched for a function whose name starts with the struct
// name and ends with "$<method>" (trait methods). On success the node's
// method_name is replaced by the symbol's actual name, the arguments are
// checked against parameters 1..n (parameter 0 is the receiver) and the node
// takes the method's return type.
void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    analyze_expr(node->as.method_call.receiver, scope, errors, a);
    const char* recv_struct = node->as.method_call.receiver->type.struct_name;
    if (node->as.method_call.receiver->type.kind != TYPE_STRUCT || !recv_struct) {
        error_add(errors, "", node->loc.line, node->loc.col, "只有结构体类型支持方法调用");
        node->type.kind = TYPE_ERROR;
        return;
    }

    // Direct lookup. A truncated mangled name could match an unrelated symbol,
    // so it is treated as "not found".
    char mangled[256];
    int mangled_len = snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct,
                               node->as.method_call.method_name);
    Symbol* sym = NULL;
    if (mangled_len >= 0 && (size_t)mangled_len < sizeof(mangled)) {
        sym = scope_lookup(scope, mangled);
    }

    // Trait-method fallback: search every scope for a function symbol that
    // starts with the struct name and ends with "$method_name".
    if (!sym || sym->kind != SYM_FUNCTION) {
        // Discard a non-function hit; otherwise the `if (sym) break;` below
        // would end the search after the innermost scope.
        sym = NULL;

        const char* method = node->as.method_call.method_name;
        size_t method_len = strlen(method);
        size_t suf_len = method_len + 1;  // '$' followed by the method name
        size_t recv_len = strlen(recv_struct);

        for (const Scope* sc = scope; sc; sc = sc->parent) {
            for (Symbol* s = sc->head; s; s = s->next) {
                if (s->kind != SYM_FUNCTION) continue;
                size_t name_len = strlen(s->name);
                // NOTE(review): the prefix test has no delimiter check, so
                // struct "Foo" also matches symbols of a struct "FooBar" --
                // confirm against the trait-method mangling scheme.
                if (name_len > suf_len + recv_len &&
                    strncmp(s->name, recv_struct, recv_len) == 0 &&
                    s->name[name_len - suf_len] == '$' &&
                    strcmp(s->name + name_len - method_len, method) == 0) {
                    sym = s;
                    break;
                }
            }
            if (sym) break;
        }
    }

    if (!sym || sym->kind != SYM_FUNCTION) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "结构体 '%s' 没有方法 '%s'", recv_struct, node->as.method_call.method_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    // Store the symbol's actual name (per the original comment, codegen uses
    // it to find the LLVM function).
    node->as.method_call.method_name = sym->name;

    if (node->as.method_call.arg_count + 1 != sym->param_count) {
        error_add(errors, "", node->loc.line, node->loc.col,
                  "方法 '%s' 需要 %zu 个参数,提供了 %zu 个", node->as.method_call.method_name,
                  sym->param_count > 0 ? sym->param_count - 1 : 0,
                  node->as.method_call.arg_count);
        node->type.kind = TYPE_ERROR;
        return;
    }

    if (!reorder_named_args(node, sym, 1, errors, node->as.method_call.method_name)) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    for (size_t i = 0; i < node->as.method_call.arg_count; i++) {
        analyze_expr(node->as.method_call.args[i], scope, errors, a);
        check_arg_type(node->as.method_call.args[i], sym->param_types[i + 1],
                       sym->param_struct_names ? sym->param_struct_names[i + 1] : NULL, i, node,
                       sym, errors, a);
    }

    node->type.kind = sym->return_type;
    if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) {
        node->type.struct_name = sym->return_struct_type_name;
    }
}

// === Expression type checking (visitor dispatcher) ===

// Generates a wrapper `name##_wrap(void* ctx, AstNode* node)` that unpacks the
// SemaCtx and calls the actual handler `name`.
#define SEMA_HANDLER(name)                                  \
    static void* name##_wrap(void* vctx, AstNode* node) {   \
        SemaCtx* s = (SemaCtx*)vctx;                        \
        name(node, s->scope, s->errors, s->a);              \
        return NULL;                                        \
    }

SEMA_HANDLER(analyze_ident_expr)
SEMA_HANDLER(analyze_unary_expr)
SEMA_HANDLER(analyze_binary_expr)
SEMA_HANDLER(analyze_call_expr)
SEMA_HANDLER(analyze_field_access)
SEMA_HANDLER(analyze_struct_init)
SEMA_HANDLER(analyze_enum_variant)
SEMA_HANDLER(analyze_index_expr)
SEMA_HANDLER(analyze_method_call)
SEMA_HANDLER(analyze_node)  // if-expressions / blocks are delegated to analyze_node

static AstDispatch sema_dispatch;

// Registers the expression handlers in the dispatch table. When a new
// expression node kind is added, register its handler here (the original
// author notes that the compiler warns about missing ones).
void analyze_expr_init(void) {
    sema_dispatch.ctx = NULL;  // set by analyze_expr on every call

    ast_dispatch_set(&sema_dispatch, AST_IDENT_EXPR, analyze_ident_expr_wrap);
    ast_dispatch_set(&sema_dispatch, AST_UNARY_EXPR, analyze_unary_expr_wrap);
    ast_dispatch_set(&sema_dispatch, AST_BINARY_EXPR, analyze_binary_expr_wrap);
    ast_dispatch_set(&sema_dispatch, AST_CALL_EXPR, analyze_call_expr_wrap);
    ast_dispatch_set(&sema_dispatch, AST_FIELD_ACCESS, analyze_field_access_wrap);
    ast_dispatch_set(&sema_dispatch, AST_STRUCT_INIT, analyze_struct_init_wrap);
    ast_dispatch_set(&sema_dispatch, AST_ENUM_VARIANT, analyze_enum_variant_wrap);
    ast_dispatch_set(&sema_dispatch, AST_INDEX_EXPR, analyze_index_expr_wrap);
    ast_dispatch_set(&sema_dispatch, AST_METHOD_CALL, analyze_method_call_wrap);
    ast_dispatch_set(&sema_dispatch, AST_IF_STMT, analyze_node_wrap);
    ast_dispatch_set(&sema_dispatch, AST_BLOCK, analyze_node_wrap);
}

// Type-checks the expression `node` by dispatching on its kind. Handlers call
// this function recursively for sub-expressions.
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    SemaCtx sctx = {scope, errors, a};

    // The dispatch table is a file-scope static while the context lives on
    // this stack frame. Restore the caller's context on the way out so that
    // the table never keeps a pointer into a frame that has already returned
    // (nested calls hand it back to the enclosing call's live context).
    void* saved_ctx = sema_dispatch.ctx;
    sema_dispatch.ctx = &sctx;
    ast_visit(&sema_dispatch, node);
    sema_dispatch.ctx = saved_ctx;
}