560 lines
26 KiB
C
560 lines
26 KiB
C
#include "sema_internal.h"

#include <stdlib.h>
|
||
|
||
// === 泛型单态化: AST 类型替换 ===
|
||
|
||
// === 类型关系 ===
|
||
// Return type associated with the function currently under analysis; starts as TYPE_VOID.
// NOTE(review): nothing in this chunk writes it — presumably set by the function-level
// analyzer before checking a body and read by return-statement checks; confirm.
TypeKind current_return_type = TYPE_VOID;
|
||
// Struct name companion to current_return_type; starts as NULL.
// NOTE(review): presumably only meaningful when current_return_type == TYPE_STRUCT — confirm
// against the code that assigns it (not visible here).
const char* current_return_struct_name = NULL;
|
||
|
||
|
||
// === 类型关系(基于 TypeTable 数据驱动)===
|
||
/*
 * Computes the result type of a binary arithmetic operation on `a` and `b`.
 *
 * For arithmetic, enums are treated as i64 and chars as i32. When both
 * (adjusted) operands are numeric according to the TypeTable, the operand
 * with the higher promote_rank wins; on a tie the left operand wins.
 * Returns TYPE_ERROR when either operand is not numeric.
 */
TypeKind promote(TypeKind a, TypeKind b) {
    TypeKind lhs = (a == TYPE_ENUM) ? TYPE_I64 : (a == TYPE_CHAR) ? TYPE_I32 : a;
    TypeKind rhs = (b == TYPE_ENUM) ? TYPE_I64 : (b == TYPE_CHAR) ? TYPE_I32 : b;

    const TypeDesc* lhs_desc = type_lookup(lhs);
    const TypeDesc* rhs_desc = type_lookup(rhs);

    if (lhs_desc->is_numeric && rhs_desc->is_numeric) {
        return (rhs_desc->promote_rank > lhs_desc->promote_rank) ? rhs : lhs;
    }
    return TYPE_ERROR;
}
|
||
|
||
/* Reports whether the TypeTable marks `t` as a numeric type. */
bool is_numeric(TypeKind t) {
    const TypeDesc* desc = type_lookup(t);
    return desc->is_numeric;
}
|
||
// 隐式转换: 加宽允许, 同 bit_width 的有/无符号双向允许 (u64↔i64)
|
||
bool can_implicit_convert(TypeKind from, TypeKind to) {
|
||
if (from == to) return true;
|
||
if (from == TYPE_ENUM) from = TYPE_I64;
|
||
if (to == TYPE_ENUM) to = TYPE_I64;
|
||
if (from == to) return true;
|
||
const TypeDesc* tf = type_lookup(from);
|
||
const TypeDesc* tt = type_lookup(to);
|
||
if (!tf->is_numeric || !tt->is_numeric) return false;
|
||
// 同 bit_width: 有/无符号整数双向允许 (u64↔i64, 在 LLVM 中同为 64-bit)
|
||
if (tf->bit_width == tt->bit_width && tf->bit_width >= 32)
|
||
return true;
|
||
// 加宽转换: 有符号→任意, 无符号→仅无符号
|
||
if (tt->promote_rank > tf->promote_rank)
|
||
return tf->is_signed || !tt->is_signed;
|
||
return false;
|
||
}
|
||
|
||
/*
 * Reports whether values of types `a` and `b` may be compared with
 * ==, !=, <, >, <= or >=: identical types always may; otherwise both must be
 * numeric, with enums treated as i64.
 */
bool is_comparable(TypeKind a, TypeKind b) {
    if (a == b) return true;
    TypeKind lhs = (a == TYPE_ENUM) ? TYPE_I64 : a;
    TypeKind rhs = (b == TYPE_ENUM) ? TYPE_I64 : b;
    return type_lookup(lhs)->is_numeric && type_lookup(rhs)->is_numeric;
}
|
||
|
||
// === 向前声明 ===
|
||
// General node analyzer (defined elsewhere). In this file it is wrapped by
// SEMA_HANDLER and registered as the expression handler for AST_IF_STMT and AST_BLOCK.
void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
|
||
|
||
// === 表达式类型检查辅助函数 ===
|
||
// Expression analyzer entry point (defined at the end of this file); declared early so the
// per-node helpers below can recurse into their sub-expressions.
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
|
||
|
||
/*
 * Types an identifier expression from the symbol it names.
 *
 * Reports an error and yields TYPE_ERROR when the name is undefined, names a
 * type alias, or names a function (neither is a value here). Otherwise the
 * node takes the symbol's type; struct symbols also carry their struct name,
 * and array symbols carry element type, element struct name and size.
 */
void analyze_ident_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    (void)a;
    const Symbol* sym = scope_lookup(scope, node->as.ident.name);

    if (sym == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的变量 '%s'", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sym->is_type_alias) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'%s' 是类型别名,不能作为表达式使用", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sym->kind == SYM_FUNCTION) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'%s' 是函数,不能作为表达式使用", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.kind = sym->type;
    switch (sym->type) {
    case TYPE_STRUCT:
        if (sym->struct_type_name)
            node->type.struct_name = sym->struct_type_name;
        break;
    case TYPE_ARRAY:
        node->type.element_type = sym->array_element_type;
        node->type.element_struct_name = sym->array_element_struct_name;
        node->type.array_size = sym->array_size;
        break;
    default:
        break;
    }
}
|
||
|
||
/*
 * Type-checks a unary expression.
 *
 * OP_NEG requires a numeric operand and yields the operand's type. Every other
 * operator is treated as logical NOT ('!'), which requires bool and yields bool.
 *
 * If the operand already failed analysis (TYPE_ERROR), the error is propagated
 * silently: the operand's own diagnostic was already reported. Without this
 * check a bad operand produced a second, misleading message such as
 * "'!' 只能用于布尔类型,得到 'error'". analyze_binary_expr handles bad operands
 * the same way.
 */
void analyze_unary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    analyze_expr(node->as.unary.operand, scope, errors, a);
    TypeKind inner = node->as.unary.operand->type.kind;

    if (inner == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    if (node->as.unary.op == OP_NEG) {
        if (!is_numeric(inner)) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "一元 '-' 只能用于数值类型");
            node->type.kind = TYPE_ERROR;
        } else {
            node->type.kind = inner;
        }
        return;
    }

    // OP_NOT
    if (inner != TYPE_BOOL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'!' 只能用于布尔类型,得到 '%s'", type_name(inner));
        node->type.kind = TYPE_ERROR;
    } else {
        node->type.kind = TYPE_BOOL;
    }
}
|
||
|
||
/*
 * Type-checks a binary expression after analyzing both operands.
 *
 *   +              : str + str -> str; mixing str with another type is an
 *                    error; otherwise numeric arithmetic as below.
 *   - * / %        : both operands numeric; result is promote(l, r).
 *   == != < > <= >=: operands must satisfy is_comparable; result is bool.
 *   && ||          : both operands bool; result is bool.
 *
 * If either operand is TYPE_ERROR the result is TYPE_ERROR with no extra
 * diagnostic. For an operator not listed above, the node's type is left
 * untouched.
 */
void analyze_binary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    analyze_expr(node->as.binary.left, scope, errors, a);
    analyze_expr(node->as.binary.right, scope, errors, a);

    const TypeKind lt = node->as.binary.left->type.kind;
    const TypeKind rt = node->as.binary.right->type.kind;

    if (lt == TYPE_ERROR || rt == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    TypeKind result = TYPE_ERROR;
    switch (node->as.binary.op) {
    case OP_ADD:
        if (lt == TYPE_STR && rt == TYPE_STR) {
            result = TYPE_STR;
            break;
        }
        if (lt == TYPE_STR || rt == TYPE_STR) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'",
                      type_name(lt), type_name(rt));
            result = TYPE_ERROR;
            break;
        }
        /* fallthrough: non-string '+' is ordinary arithmetic */
    case OP_SUB: case OP_MUL: case OP_DIV: case OP_MOD:
        if (is_numeric(lt) && is_numeric(rt)) {
            result = promote(lt, rt);
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col, "算术运算需要数值类型");
            result = TYPE_ERROR;
        }
        break;

    case OP_EQ: case OP_NE: case OP_LT: case OP_GT: case OP_LE: case OP_GE:
        if (is_comparable(lt, rt)) {
            result = TYPE_BOOL;
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "类型 '%s' 和 '%s' 无法比较", type_name(lt), type_name(rt));
            result = TYPE_ERROR;
        }
        break;

    case OP_AND: case OP_OR:
        if (lt == TYPE_BOOL && rt == TYPE_BOOL) {
            result = TYPE_BOOL;
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col, "逻辑运算需要布尔类型");
            result = TYPE_ERROR;
        }
        break;

    default:
        return;  // unrecognized operator: leave the node's type as it was
    }

    node->type.kind = result;
}
|
||
|
||
// 参数类型匹配检查(CALL_EXPR 和 METHOD_CALL 共用)
|
||
bool check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname,
|
||
size_t idx, AstNode* call_node, Symbol* fn_sym,
|
||
ErrorList* errors, Arena* a) {
|
||
(void)a;
|
||
TypeKind actual = arg->type.kind;
|
||
if (actual == TYPE_ERROR) return false;
|
||
if (expected == TYPE_STRUCT && expected_sname) {
|
||
// 检查是否是泛型类型参数(匹配则接受任意类型)
|
||
if (fn_sym && fn_sym->type_params) {
|
||
for (size_t t = 0; t < fn_sym->type_param_count; t++) {
|
||
if (strcmp(expected_sname, fn_sym->type_params[t]) == 0)
|
||
return true; // 泛型参数,接受任意类型
|
||
}
|
||
}
|
||
const char* actual_name = arg->type.struct_name;
|
||
if (actual != TYPE_STRUCT || !actual_name ||
|
||
strcmp(actual_name, expected_sname) != 0) {
|
||
error_add(errors, "<sema>", call_node->loc.line, call_node->loc.col,
|
||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||
idx + 1, expected_sname ? expected_sname : "struct",
|
||
actual_name ? actual_name : type_name(actual));
|
||
}
|
||
return false;
|
||
}
|
||
if (actual == expected) return false;
|
||
if (expected == TYPE_I64 && actual == TYPE_ENUM) return false;
|
||
if (can_implicit_convert(actual, expected)) return false;
|
||
if (actual == TYPE_I64 && arg->kind == AST_LITERAL_EXPR
|
||
&& (expected == TYPE_I32 || expected == TYPE_U64 || expected == TYPE_CHAR)) return false;
|
||
error_add(errors, "<sema>", call_node->loc.line, call_node->loc.col,
|
||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||
idx + 1, type_name(expected), type_name(actual));
|
||
return false;
|
||
}
|
||
|
||
// 命名参数重排序(CALL_EXPR 和 METHOD_CALL 共用)
|
||
bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset,
|
||
ErrorList* errors, const char* call_name) {
|
||
AstNode** args = node->as.call.args;
|
||
const char** arg_names = node->as.call.arg_names;
|
||
size_t arg_count = node->as.call.arg_count;
|
||
if (!arg_names) return true;
|
||
AstNode* reordered[16] = {0};
|
||
for (size_t i = 0; i < arg_count; i++) {
|
||
if (arg_names[i]) {
|
||
bool found = false;
|
||
for (size_t j = param_offset; j < sym->param_count; j++) {
|
||
if (sym->param_names && sym->param_names[j] &&
|
||
strcmp(arg_names[i], sym->param_names[j]) == 0) {
|
||
reordered[j - param_offset] = args[i];
|
||
found = true; break;
|
||
}
|
||
}
|
||
if (!found) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"'%s' 没有名为 '%s' 的参数", call_name, arg_names[i]);
|
||
return false;
|
||
}
|
||
} else {
|
||
reordered[i] = args[i];
|
||
}
|
||
}
|
||
memcpy(args, reordered, arg_count * sizeof(AstNode*));
|
||
return true;
|
||
}
|
||
|
||
void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||
Symbol* sym = scope_lookup(scope, node->as.call.name);
|
||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"未定义的函数 '%s'", node->as.call.name);
|
||
node->type.kind = TYPE_ERROR;
|
||
for (size_t i = 0; i < node->as.call.arg_count; i++)
|
||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||
return;
|
||
}
|
||
if (node->as.call.arg_count != sym->param_count) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"函数 '%s' 需要 %zu 个参数,但提供了 %zu 个",
|
||
node->as.call.name, sym->param_count, node->as.call.arg_count);
|
||
node->type.kind = TYPE_ERROR;
|
||
for (size_t i = 0; i < node->as.call.arg_count; i++)
|
||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||
return;
|
||
}
|
||
if (!reorder_named_args(node, sym, 0, errors, node->as.call.name)) {
|
||
node->type.kind = TYPE_ERROR; return;
|
||
}
|
||
for (size_t i = 0; i < node->as.call.arg_count; i++) {
|
||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||
bool is_generic_param = check_arg_type(node->as.call.args[i], sym->param_types[i],
|
||
sym->param_struct_names ? sym->param_struct_names[i] : NULL,
|
||
i, node, sym, errors, a);
|
||
// 泛型单态化: 创建具象化函数副本并注册
|
||
if (is_generic_param && sym->type_params && sym->type_param_count > 0) {
|
||
TypeKind concrete = node->as.call.args[i]->type.kind;
|
||
const char* concrete_sn = node->as.call.args[i]->type.struct_name;
|
||
// 构造 mangled 名: fn$concrete_type
|
||
const char* ct_name = concrete_sn ? concrete_sn : type_name(concrete);
|
||
int mname_len = snprintf(NULL, 0, "%s$%s", node->as.call.name, ct_name) + 1;
|
||
char* mname = arena_alloc_impl(a, mname_len);
|
||
snprintf(mname, mname_len, "%s$%s", node->as.call.name, ct_name);
|
||
// 检查是否已存在
|
||
Symbol* existing = scope_lookup(scope, mname);
|
||
if (!existing && g_program) {
|
||
// 查找原始泛型函数 AST 节点
|
||
AstNode* generic_fn = NULL;
|
||
for (size_t fn_i = 0; fn_i < g_program->as.program.fn_count; fn_i++) {
|
||
if (strcmp(g_program->as.program.functions[fn_i]->as.function.name,
|
||
node->as.call.name) == 0) {
|
||
generic_fn = g_program->as.program.functions[fn_i];
|
||
break;
|
||
}
|
||
}
|
||
if (generic_fn && mono_count < 256) {
|
||
// 创建浅拷贝(共享 body,subst_ast_types 修改类型标注)
|
||
AstNode* mono_fn = ast_make_function(a, mname,
|
||
generic_fn->as.function.params,
|
||
generic_fn->as.function.param_count,
|
||
generic_fn->as.function.return_type,
|
||
generic_fn->as.function.return_struct_type_name,
|
||
generic_fn->as.function.body,
|
||
false, NULL, 0,
|
||
generic_fn->loc);
|
||
// 类型替换: T → concrete
|
||
subst_ast_types(mono_fn, sym->type_params[0], concrete, concrete_sn);
|
||
// 注册到队列
|
||
mono_queue[mono_count++] = mono_fn;
|
||
// 注册符号(后续分析会处理函数体)
|
||
TypeKind* mpts = mono_fn->as.function.param_count > 0
|
||
? arena_alloc_impl(a, mono_fn->as.function.param_count * sizeof(TypeKind)) : NULL;
|
||
for (size_t pj = 0; pj < mono_fn->as.function.param_count; pj++) {
|
||
mpts[pj] = mono_fn->as.function.params[pj]->as.parameter.type;
|
||
}
|
||
scope_insert_function(scope, a, mname,
|
||
mono_fn->as.function.return_type,
|
||
mono_fn->as.function.return_struct_type_name,
|
||
mpts, NULL, NULL,
|
||
mono_fn->as.function.param_count, NULL, 0);
|
||
}
|
||
}
|
||
// 重定向调用到单态化函数
|
||
node->as.call.name = mname;
|
||
sym = scope_lookup(scope, mname);
|
||
if (!sym) { node->type.kind = TYPE_ERROR; return; }
|
||
}
|
||
}
|
||
node->type.kind = sym->return_type;
|
||
if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name)
|
||
node->type.struct_name = sym->return_struct_type_name;
|
||
}
|
||
|
||
/*
 * Types `object.field`.
 *
 * The object must be a struct whose name is known and registered. The node
 * gets the field's declared type, and field_index records the field's
 * position. Struct-typed fields also carry their struct name when the struct
 * symbol records one.
 *
 * An object already typed TYPE_ERROR propagates silently; every other failure
 * reports a diagnostic and yields TYPE_ERROR.
 */
void analyze_field_access(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    analyze_expr(node->as.field_access.object, scope, errors, a);
    const AstNode* object = node->as.field_access.object;
    const TypeKind object_kind = object->type.kind;

    if (object_kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (object_kind != TYPE_STRUCT) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "类型 '%s' 不是结构体,不能访问字段 '%s'",
                  type_name(object_kind), node->as.field_access.field);
        node->type.kind = TYPE_ERROR;
        return;
    }

    const char* owner_name = object->type.struct_name;
    if (owner_name == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col, "无法确定结构体类型");
        node->type.kind = TYPE_ERROR;
        return;
    }

    Symbol* owner = scope_lookup_struct(scope, owner_name);
    if (owner == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的结构体 '%s'", owner_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    int slot = scope_struct_field_index(owner, node->as.field_access.field);
    if (slot < 0) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 没有字段 '%s'", owner_name, node->as.field_access.field);
        node->type.kind = TYPE_ERROR;
        return;
    }

    const TypeKind field_kind = owner->struct_field_types[slot];
    node->type.kind = field_kind;
    node->as.field_access.field_index = slot;

    if (field_kind == TYPE_STRUCT && owner->struct_field_struct_names != NULL) {
        const char* field_struct = owner->struct_field_struct_names[slot];
        if (field_struct != NULL)
            node->type.struct_name = field_struct;
    }
}
|
||
|
||
/*
 * Type-checks a struct literal `Name { field: value, ... }`.
 *
 * The type name may be a struct or a type alias of one; an alias is resolved
 * and written back into the node. The literal must supply exactly as many
 * fields as the struct declares. Each field name must exist and appear only
 * once, and each value's type kind must equal the field's declared kind.
 *
 * Unknown, duplicated or miscounted fields make the node TYPE_ERROR. A value
 * type mismatch is reported but the node still becomes the struct type. On
 * success the node is TYPE_STRUCT with the resolved struct name.
 *
 * Changes from the previous version:
 *   - Duplicate initializers are now rejected. Previously `S { x: 1, x: 2 }`
 *     passed the count check while leaving another field uninitialized.
 *   - Failure is tracked in a local flag instead of reading node->type.kind,
 *     which could hold a stale value when the node is analyzed more than once.
 */
void analyze_struct_init(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    const char* resolved = node->as.struct_init.type_name;
    Symbol* struct_sym = scope_lookup_struct(scope, resolved);
    if (!struct_sym) {
        Symbol* alias_sym = scope_lookup(scope, resolved);
        if (alias_sym && alias_sym->is_type_alias && alias_sym->struct_type_name) {
            resolved = alias_sym->struct_type_name;
            struct_sym = scope_lookup_struct(scope, resolved);
            node->as.struct_init.type_name = resolved;
        }
    }
    if (!struct_sym) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的结构体类型 '%s'", node->as.struct_init.type_name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (node->as.struct_init.field_count != struct_sym->struct_field_count) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 有 %zu 个字段,但提供了 %zu 个",
                  node->as.struct_init.type_name,
                  struct_sym->struct_field_count, node->as.struct_init.field_count);
        node->type.kind = TYPE_ERROR;
        return;
    }

    bool ok = true;
    for (size_t i = 0; i < node->as.struct_init.field_count; i++) {
        const char* fname = node->as.struct_init.field_names[i];
        AstNode* fval = node->as.struct_init.field_values[i];
        analyze_expr(fval, scope, errors, a);

        int fi = scope_struct_field_index(struct_sym, fname);
        if (fi < 0) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "结构体 '%s' 没有字段 '%s'", node->as.struct_init.type_name, fname);
            ok = false;
            continue;
        }

        // With the count already equal, a repeated field means another field is missing.
        bool duplicate = false;
        for (size_t j = 0; j < i; j++) {
            if (strcmp(node->as.struct_init.field_names[j], fname) == 0) {
                duplicate = true;
                break;
            }
        }
        if (duplicate) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "结构体 '%s' 的字段 '%s' 被重复初始化",
                      node->as.struct_init.type_name, fname);
            ok = false;
            continue;
        }

        TypeKind expected = struct_sym->struct_field_types[fi];
        TypeKind actual = fval->type.kind;
        if (actual != TYPE_ERROR && actual != expected)
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'",
                      fname, type_name(expected), type_name(actual));
    }

    if (ok) {
        node->type.kind = TYPE_STRUCT;
        node->type.struct_name = resolved;
    } else {
        node->type.kind = TYPE_ERROR;
    }
}
|
||
|
||
/*
 * Types an enum variant expression `Enum::Variant` or `Enum::Variant(payload)`.
 *
 * The enum must be registered (as SYM_ENUM) and must contain the variant. The
 * variant's position is stored in variant_index.
 *
 * When a payload is written:
 *   - if the enum records payload types and this variant's is TYPE_VOID, an
 *     error is reported and the node is TYPE_ERROR;
 *   - otherwise the payload is analyzed and its type kind must equal the
 *     expected payload type (a mismatch is reported but the node stays an enum).
 *
 * NOTE(review): a variant that expects a payload but is written without one is
 * not diagnosed here — confirm whether another pass checks that.
 */
void analyze_enum_variant(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    Symbol* enum_sym = scope_lookup_struct(scope, node->as.enum_variant.enum_name);
    if (enum_sym == NULL || enum_sym->kind != SYM_ENUM) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的枚举 '%s'", node->as.enum_variant.enum_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    int variant = scope_enum_variant_index(enum_sym, node->as.enum_variant.variant_name);
    if (variant < 0) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "枚举 '%s' 没有变体 '%s'",
                  node->as.enum_variant.enum_name, node->as.enum_variant.variant_name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    node->as.enum_variant.variant_index = variant;

    // ADT payload check
    AstNode* payload = node->as.enum_variant.payload;
    if (payload != NULL) {
        const bool has_payload_table = enum_sym->variant_payload_types != NULL;
        const TypeKind want = has_payload_table
            ? enum_sym->variant_payload_types[variant] : TYPE_VOID;

        if (has_payload_table && want == TYPE_VOID) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "枚举变体 '%s::%s' 不接受 payload",
                      node->as.enum_variant.enum_name, node->as.enum_variant.variant_name);
            node->type.kind = TYPE_ERROR;
            return;
        }

        analyze_expr(payload, scope, errors, a);
        const TypeKind got = payload->type.kind;
        if (got != TYPE_ERROR && got != want) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "枚举变体 payload 类型不匹配: 期望 '%s',得到 '%s'",
                      type_name(want), type_name(got));
        }
    }

    node->type.kind = TYPE_ENUM;
}
|
||
|
||
/*
 * Types `array[index]`.
 *
 * The base must be an array and the index must be exactly i64. The result
 * takes the array's element type and element struct name. A base or index
 * already typed TYPE_ERROR propagates silently (the base is checked first).
 */
void analyze_index_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* base = node->as.index_expr.array;
    AstNode* subscript = node->as.index_expr.index;
    analyze_expr(base, scope, errors, a);
    analyze_expr(subscript, scope, errors, a);

    const TypeKind base_kind = base->type.kind;
    if (base_kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (base_kind != TYPE_ARRAY) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "类型 '%s' 不支持索引操作", type_name(base_kind));
        node->type.kind = TYPE_ERROR;
        return;
    }

    const TypeKind sub_kind = subscript->type.kind;
    if (sub_kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sub_kind != TYPE_I64) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "数组索引必须是 i64 类型, 得到 '%s'", type_name(sub_kind));
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.kind = base->type.element_type;
    node->type.struct_name = base->type.element_struct_name;
}
|
||
|
||
void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||
analyze_expr(node->as.method_call.receiver, scope, errors, a);
|
||
const char* recv_struct = node->as.method_call.receiver->type.struct_name;
|
||
if (node->as.method_call.receiver->type.kind != TYPE_STRUCT || !recv_struct) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"只有结构体类型支持方法调用");
|
||
node->type.kind = TYPE_ERROR; return;
|
||
}
|
||
char mangled[256];
|
||
snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct,
|
||
node->as.method_call.method_name);
|
||
Symbol* sym = scope_lookup(scope, mangled);
|
||
// trait 方法 fallback: 搜索所有作用域中以 $method_name 结尾且以 StructName 开头的符号
|
||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||
char suffix[256];
|
||
snprintf(suffix, sizeof(suffix), "$%s", node->as.method_call.method_name);
|
||
size_t suf_len = strlen(suffix);
|
||
size_t recv_len = strlen(recv_struct);
|
||
for (const Scope* sc = scope; sc; sc = sc->parent) {
|
||
for (Symbol* s = sc->head; s; s = s->next) {
|
||
if (s->kind == SYM_FUNCTION) {
|
||
size_t name_len = strlen(s->name);
|
||
if (name_len > suf_len + recv_len
|
||
&& strncmp(s->name, recv_struct, recv_len) == 0
|
||
&& strcmp(s->name + name_len - suf_len, suffix) == 0) {
|
||
sym = s;
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
if (sym) break;
|
||
}
|
||
}
|
||
// 更新 method_name 为符号的实际名称(codegen 需要通过它找到 LLVM 函数)
|
||
if (sym && sym->kind == SYM_FUNCTION) {
|
||
node->as.method_call.method_name = sym->name;
|
||
}
|
||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"结构体 '%s' 没有方法 '%s'", recv_struct,
|
||
node->as.method_call.method_name);
|
||
node->type.kind = TYPE_ERROR; return;
|
||
}
|
||
if (node->as.method_call.arg_count + 1 != sym->param_count) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"方法 '%s' 需要 %zu 个参数,提供了 %zu 个",
|
||
node->as.method_call.method_name,
|
||
sym->param_count > 0 ? sym->param_count - 1 : 0,
|
||
node->as.method_call.arg_count);
|
||
node->type.kind = TYPE_ERROR; return;
|
||
}
|
||
if (!reorder_named_args(node, sym, 1, errors, node->as.method_call.method_name)) {
|
||
node->type.kind = TYPE_ERROR; return;
|
||
}
|
||
for (size_t i = 0; i < node->as.method_call.arg_count; i++) {
|
||
analyze_expr(node->as.method_call.args[i], scope, errors, a);
|
||
check_arg_type(node->as.method_call.args[i], sym->param_types[i + 1],
|
||
sym->param_struct_names ? sym->param_struct_names[i + 1] : NULL,
|
||
i, node, sym, errors, a);
|
||
}
|
||
node->type.kind = sym->return_type;
|
||
if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name)
|
||
node->type.struct_name = sym->return_struct_type_name;
|
||
}
|
||
|
||
// === 表达式类型检查(Visitor 调度器) ===
|
||
// 每个 handler 包装函数: (void* ctx, AstNode* node) → 调用实际 handler
|
||
#define SEMA_HANDLER(name) \
|
||
static void* name##_wrap(void* vctx, AstNode* node) { \
|
||
SemaCtx* s = (SemaCtx*)vctx; \
|
||
name(node, s->scope, s->errors, s->a); \
|
||
return NULL; \
|
||
}
|
||
|
||
SEMA_HANDLER(analyze_ident_expr)
|
||
SEMA_HANDLER(analyze_unary_expr)
|
||
SEMA_HANDLER(analyze_binary_expr)
|
||
SEMA_HANDLER(analyze_call_expr)
|
||
SEMA_HANDLER(analyze_field_access)
|
||
SEMA_HANDLER(analyze_struct_init)
|
||
SEMA_HANDLER(analyze_enum_variant)
|
||
SEMA_HANDLER(analyze_index_expr)
|
||
SEMA_HANDLER(analyze_method_call)
|
||
SEMA_HANDLER(analyze_node) // if-expr / block 委托
|
||
|
||
static AstDispatch sema_dispatch;
|
||
|
||
void analyze_expr_init(void) {
|
||
sema_dispatch.ctx = NULL; // 由 analyze_expr 每次设置
|
||
// 新增表达式节点: 在此注册 handler, 编译器会警告缺失
|
||
ast_dispatch_set(&sema_dispatch, AST_IDENT_EXPR, analyze_ident_expr_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_UNARY_EXPR, analyze_unary_expr_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_BINARY_EXPR, analyze_binary_expr_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_CALL_EXPR, analyze_call_expr_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_FIELD_ACCESS, analyze_field_access_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_STRUCT_INIT, analyze_struct_init_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_ENUM_VARIANT, analyze_enum_variant_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_INDEX_EXPR, analyze_index_expr_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_METHOD_CALL, analyze_method_call_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_IF_STMT, analyze_node_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_BLOCK, analyze_node_wrap);
|
||
}
|
||
|
||
/*
 * Type-checks the expression `node` by dispatching on its kind through
 * sema_dispatch.
 *
 * The dispatcher's ctx is pointed at a context local to this call. The
 * SEMA_HANDLER wrappers copy scope/errors/arena out of it before doing any
 * work, so nested analyze_expr calls overwriting ctx do not disturb an outer
 * handler. NOTE(review): this assumes ast_visit reads ctx once, before invoking
 * the handler — confirm.
 */
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    SemaCtx call_ctx = { .scope = scope, .errors = errors, .a = a };
    sema_dispatch.ctx = &call_ctx;
    ast_visit(&sema_dispatch, node);
}
|