06d80f441a
fn(x: T) -> R { body } 作为表达式, 可赋值给变量并间接调用。
全流水线实现:
- Parser: TOK_FN 前缀 → AST_LAMBDA 节点
- Sema: 自动生成 __lambda_N 顶层函数 + 符号注册
- Sema: analyze_call_expr 支持 TYPE_CLOSURE 变量调用
- Codegen: lambda 表达式返回函数指针(i64), 调用点载入+IntToPtr+间接call
- VarEntry.closure_fn 追踪闭包变量对应的生成函数
限制(MVP v0.1): 非捕获 lambda, 返回类型固定 i64
+6 sema 测试 + 1 集成测试, 209 测试全部通过
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
638 lines
29 KiB
C
638 lines
29 KiB
C
#include "sema_internal.h"

#include <stdlib.h>
|
||
|
||
// === 泛型单态化: AST 类型替换 ===
|
||
|
||
// === 类型关系 ===
|
||
TypeKind current_return_type = TYPE_VOID;
|
||
const char* current_return_struct_name = NULL;
|
||
|
||
|
||
// === 类型关系(基于 TypeTable 数据驱动)===
|
||
TypeKind promote(TypeKind a, TypeKind b) {
|
||
// 枚举在算术中视为 i64, char 视为 i32
|
||
if (a == TYPE_ENUM) a = TYPE_I64;
|
||
if (b == TYPE_ENUM) b = TYPE_I64;
|
||
if (a == TYPE_CHAR) a = TYPE_I32;
|
||
if (b == TYPE_CHAR) b = TYPE_I32;
|
||
const TypeDesc* ta = type_lookup(a);
|
||
const TypeDesc* tb = type_lookup(b);
|
||
if (!ta->is_numeric || !tb->is_numeric) return TYPE_ERROR;
|
||
return ta->promote_rank >= tb->promote_rank ? a : b;
|
||
}
|
||
|
||
// Reports whether `t` is a numeric type according to its TypeTable entry.
bool is_numeric(TypeKind t) {
    const TypeDesc* desc = type_lookup(t);
    return desc->is_numeric;
}
|
||
// 隐式转换: 加宽允许, 同 bit_width 的有/无符号双向允许 (u64↔i64)
|
||
bool can_implicit_convert(TypeKind from, TypeKind to) {
|
||
if (from == to) return true;
|
||
if (from == TYPE_ENUM) from = TYPE_I64;
|
||
if (to == TYPE_ENUM) to = TYPE_I64;
|
||
if (from == to) return true;
|
||
const TypeDesc* tf = type_lookup(from);
|
||
const TypeDesc* tt = type_lookup(to);
|
||
if (!tf->is_numeric || !tt->is_numeric) return false;
|
||
// 同 bit_width: 有/无符号整数双向允许 (u64↔i64, 在 LLVM 中同为 64-bit)
|
||
if (tf->bit_width == tt->bit_width && tf->bit_width >= 32)
|
||
return true;
|
||
// 加宽转换: 有符号→任意, 无符号→仅无符号
|
||
if (tt->promote_rank > tf->promote_rank)
|
||
return tf->is_signed || !tt->is_signed;
|
||
return false;
|
||
}
|
||
|
||
// Two operand types can be compared when they are the same kind, or when both
// are numeric after treating enums as i64.
bool is_comparable(TypeKind a, TypeKind b) {
    if (a == b) {
        return true;
    }
    const TypeKind x = (a == TYPE_ENUM) ? TYPE_I64 : a;
    const TypeKind y = (b == TYPE_ENUM) ? TYPE_I64 : b;
    const bool x_numeric = type_lookup(x)->is_numeric;
    return x_numeric && type_lookup(y)->is_numeric;
}
|
||
|
||
// === 向前声明 ===
|
||
void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
|
||
|
||
// === 表达式类型检查辅助函数 ===
|
||
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
|
||
|
||
// Resolves an identifier used as an expression.
// Undefined names, type aliases and function names are rejected (TYPE_ERROR).
// Otherwise the node takes the symbol's type, plus the struct name for struct
// values and the element type / element struct name / size for arrays.
void analyze_ident_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    (void)a;
    const char* name = node->as.ident.name;
    Symbol* sym = scope_lookup(scope, name);

    if (sym == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的变量 '%s'", name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sym->is_type_alias) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'%s' 是类型别名,不能作为表达式使用", name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (sym->kind == SYM_FUNCTION) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'%s' 是函数,不能作为表达式使用", name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.kind = sym->type;
    switch (sym->type) {
    case TYPE_STRUCT:
        if (sym->struct_type_name) {
            node->type.struct_name = sym->struct_type_name;
        }
        break;
    case TYPE_ARRAY:
        node->type.element_type = sym->array_element_type;
        node->type.element_struct_name = sym->array_element_struct_name;
        node->type.array_size = sym->array_size;
        break;
    default:
        break;
    }
}
|
||
|
||
// Type-checks a unary expression after analyzing its operand.
//  - OP_NEG: the operand must be numeric; the result has the operand's type.
//  - any other operator (OP_NOT): the operand must be bool; the result is bool.
// An operand already typed TYPE_ERROR makes the result TYPE_ERROR without a
// second diagnostic, matching analyze_binary_expr. Previously it also produced
// a misleading "needs a numeric/bool operand" error.
void analyze_unary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    analyze_expr(node->as.unary.operand, scope, errors, a);
    TypeKind inner = node->as.unary.operand->type.kind;
    if (inner == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;  // operand already diagnosed
        return;
    }

    if (node->as.unary.op == OP_NEG) {
        if (!is_numeric(inner)) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "一元 '-' 只能用于数值类型");
            node->type.kind = TYPE_ERROR;
        } else {
            node->type.kind = inner;
        }
    } else {  // OP_NOT
        if (inner != TYPE_BOOL) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "'!' 只能用于布尔类型,得到 '%s'", type_name(inner));
            node->type.kind = TYPE_ERROR;
        } else {
            node->type.kind = TYPE_BOOL;
        }
    }
}
|
||
|
||
// Type-checks a binary expression after analyzing both operands.
//  - OP_ADD: two strings concatenate to str. Mixing str with a non-str is an
//    error. Otherwise it is arithmetic, exactly like the operators below.
//  - OP_SUB / OP_MUL / OP_DIV / OP_MOD: both operands numeric; the result is
//    promote(l, r).
//  - OP_EQ .. OP_GE: operands must satisfy is_comparable; the result is bool.
//  - OP_AND / OP_OR: both operands bool; the result is bool.
// If either operand is already TYPE_ERROR, the result is TYPE_ERROR and no new
// diagnostic is emitted. Operators not listed leave node->type untouched.
void analyze_binary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* lhs = node->as.binary.left;
    AstNode* rhs = node->as.binary.right;
    analyze_expr(lhs, scope, errors, a);
    analyze_expr(rhs, scope, errors, a);

    const TypeKind lt = lhs->type.kind;
    const TypeKind rt = rhs->type.kind;
    if (lt == TYPE_ERROR || rt == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    switch (node->as.binary.op) {
    case OP_ADD:
        if (lt == TYPE_STR || rt == TYPE_STR) {
            if (lt == TYPE_STR && rt == TYPE_STR) {
                node->type.kind = TYPE_STR;
            } else {
                error_add(errors, "<sema>", node->loc.line, node->loc.col,
                          "字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'",
                          type_name(lt), type_name(rt));
                node->type.kind = TYPE_ERROR;
            }
            break;
        }
        /* fallthrough: numeric '+' is ordinary arithmetic */
    case OP_SUB:
    case OP_MUL:
    case OP_DIV:
    case OP_MOD:
        if (is_numeric(lt) && is_numeric(rt)) {
            node->type.kind = promote(lt, rt);
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col, "算术运算需要数值类型");
            node->type.kind = TYPE_ERROR;
        }
        break;

    case OP_EQ:
    case OP_NE:
    case OP_LT:
    case OP_GT:
    case OP_LE:
    case OP_GE:
        if (is_comparable(lt, rt)) {
            node->type.kind = TYPE_BOOL;
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "类型 '%s' 和 '%s' 无法比较", type_name(lt), type_name(rt));
            node->type.kind = TYPE_ERROR;
        }
        break;

    case OP_AND:
    case OP_OR:
        if (lt == TYPE_BOOL && rt == TYPE_BOOL) {
            node->type.kind = TYPE_BOOL;
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col, "逻辑运算需要布尔类型");
            node->type.kind = TYPE_ERROR;
        }
        break;

    default:
        break;
    }
}
|
||
|
||
// 参数类型匹配检查(CALL_EXPR 和 METHOD_CALL 共用)
|
||
bool check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname,
|
||
size_t idx, AstNode* call_node, Symbol* fn_sym,
|
||
ErrorList* errors, Arena* a) {
|
||
(void)a;
|
||
TypeKind actual = arg->type.kind;
|
||
if (actual == TYPE_ERROR) return false;
|
||
if (expected == TYPE_STRUCT && expected_sname) {
|
||
// 检查是否是泛型类型参数(匹配则接受任意类型)
|
||
if (fn_sym && fn_sym->type_params) {
|
||
for (size_t t = 0; t < fn_sym->type_param_count; t++) {
|
||
if (strcmp(expected_sname, fn_sym->type_params[t]) == 0)
|
||
return true; // 泛型参数,接受任意类型
|
||
}
|
||
}
|
||
const char* actual_name = arg->type.struct_name;
|
||
if (actual != TYPE_STRUCT || !actual_name ||
|
||
strcmp(actual_name, expected_sname) != 0) {
|
||
error_add(errors, "<sema>", call_node->loc.line, call_node->loc.col,
|
||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||
idx + 1, expected_sname ? expected_sname : "struct",
|
||
actual_name ? actual_name : type_name(actual));
|
||
}
|
||
return false;
|
||
}
|
||
if (actual == expected) return false;
|
||
if (expected == TYPE_I64 && actual == TYPE_ENUM) return false;
|
||
if (can_implicit_convert(actual, expected)) return false;
|
||
if (actual == TYPE_I64 && arg->kind == AST_LITERAL_EXPR
|
||
&& (expected == TYPE_I32 || expected == TYPE_U64 || expected == TYPE_CHAR)) return false;
|
||
error_add(errors, "<sema>", call_node->loc.line, call_node->loc.col,
|
||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||
idx + 1, type_name(expected), type_name(actual));
|
||
return false;
|
||
}
|
||
|
||
// 命名参数重排序(CALL_EXPR 和 METHOD_CALL 共用)
|
||
bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset,
|
||
ErrorList* errors, const char* call_name) {
|
||
AstNode** args = node->as.call.args;
|
||
const char** arg_names = node->as.call.arg_names;
|
||
size_t arg_count = node->as.call.arg_count;
|
||
if (!arg_names) return true;
|
||
AstNode* reordered[16] = {0};
|
||
for (size_t i = 0; i < arg_count; i++) {
|
||
if (arg_names[i]) {
|
||
bool found = false;
|
||
for (size_t j = param_offset; j < sym->param_count; j++) {
|
||
if (sym->param_names && sym->param_names[j] &&
|
||
strcmp(arg_names[i], sym->param_names[j]) == 0) {
|
||
reordered[j - param_offset] = args[i];
|
||
found = true; break;
|
||
}
|
||
}
|
||
if (!found) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"'%s' 没有名为 '%s' 的参数", call_name, arg_names[i]);
|
||
return false;
|
||
}
|
||
} else {
|
||
reordered[i] = args[i];
|
||
}
|
||
}
|
||
memcpy(args, reordered, arg_count * sizeof(AstNode*));
|
||
return true;
|
||
}
|
||
|
||
void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||
Symbol* sym = scope_lookup(scope, node->as.call.name);
|
||
// 闭包调用: 变量类型为 TYPE_CLOSURE
|
||
if (sym && sym->kind == SYM_VARIABLE && sym->type == TYPE_CLOSURE) {
|
||
// 暂不做参数类型检查(MVP), 只分析参数表达式
|
||
for (size_t i = 0; i < node->as.call.arg_count; i++)
|
||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||
node->type.kind = TYPE_I64; // 默认返回 i64(MVP 限制)
|
||
return;
|
||
}
|
||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"未定义的函数 '%s'", node->as.call.name);
|
||
node->type.kind = TYPE_ERROR;
|
||
for (size_t i = 0; i < node->as.call.arg_count; i++)
|
||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||
return;
|
||
}
|
||
if (node->as.call.arg_count != sym->param_count) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"函数 '%s' 需要 %zu 个参数,但提供了 %zu 个",
|
||
node->as.call.name, sym->param_count, node->as.call.arg_count);
|
||
node->type.kind = TYPE_ERROR;
|
||
for (size_t i = 0; i < node->as.call.arg_count; i++)
|
||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||
return;
|
||
}
|
||
if (!reorder_named_args(node, sym, 0, errors, node->as.call.name)) {
|
||
node->type.kind = TYPE_ERROR; return;
|
||
}
|
||
for (size_t i = 0; i < node->as.call.arg_count; i++) {
|
||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||
bool is_generic_param = check_arg_type(node->as.call.args[i], sym->param_types[i],
|
||
sym->param_struct_names ? sym->param_struct_names[i] : NULL,
|
||
i, node, sym, errors, a);
|
||
// 泛型单态化: 创建具象化函数副本并注册
|
||
if (is_generic_param && sym->type_params && sym->type_param_count > 0) {
|
||
TypeKind concrete = node->as.call.args[i]->type.kind;
|
||
const char* concrete_sn = node->as.call.args[i]->type.struct_name;
|
||
// 构造 mangled 名: fn$concrete_type
|
||
const char* ct_name = concrete_sn ? concrete_sn : type_name(concrete);
|
||
int mname_len = snprintf(NULL, 0, "%s$%s", node->as.call.name, ct_name) + 1;
|
||
char* mname = arena_alloc_impl(a, mname_len);
|
||
snprintf(mname, mname_len, "%s$%s", node->as.call.name, ct_name);
|
||
// 检查是否已存在
|
||
Symbol* existing = scope_lookup(scope, mname);
|
||
if (!existing && g_program) {
|
||
// 查找原始泛型函数 AST 节点
|
||
AstNode* generic_fn = NULL;
|
||
for (size_t fn_i = 0; fn_i < g_program->as.program.fn_count; fn_i++) {
|
||
if (strcmp(g_program->as.program.functions[fn_i]->as.function.name,
|
||
node->as.call.name) == 0) {
|
||
generic_fn = g_program->as.program.functions[fn_i];
|
||
break;
|
||
}
|
||
}
|
||
if (generic_fn && mono_count < 256) {
|
||
// 创建浅拷贝(共享 body,subst_ast_types 修改类型标注)
|
||
AstNode* mono_fn = ast_make_function(a, mname,
|
||
generic_fn->as.function.params,
|
||
generic_fn->as.function.param_count,
|
||
generic_fn->as.function.return_type,
|
||
generic_fn->as.function.return_struct_type_name,
|
||
generic_fn->as.function.body,
|
||
false, NULL, 0,
|
||
generic_fn->loc);
|
||
// 类型替换: T → concrete
|
||
subst_ast_types(mono_fn, sym->type_params[0], concrete, concrete_sn);
|
||
// 注册到队列
|
||
mono_queue[mono_count++] = mono_fn;
|
||
// 注册符号(后续分析会处理函数体)
|
||
TypeKind* mpts = mono_fn->as.function.param_count > 0
|
||
? arena_alloc_impl(a, mono_fn->as.function.param_count * sizeof(TypeKind)) : NULL;
|
||
for (size_t pj = 0; pj < mono_fn->as.function.param_count; pj++) {
|
||
mpts[pj] = mono_fn->as.function.params[pj]->as.parameter.type;
|
||
}
|
||
scope_insert_function(scope, a, mname,
|
||
mono_fn->as.function.return_type,
|
||
mono_fn->as.function.return_struct_type_name,
|
||
mpts, NULL, NULL,
|
||
mono_fn->as.function.param_count, NULL, 0);
|
||
}
|
||
}
|
||
// 重定向调用到单态化函数
|
||
node->as.call.name = mname;
|
||
sym = scope_lookup(scope, mname);
|
||
if (!sym) { node->type.kind = TYPE_ERROR; return; }
|
||
}
|
||
}
|
||
node->type.kind = sym->return_type;
|
||
if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name)
|
||
node->type.struct_name = sym->return_struct_type_name;
|
||
}
|
||
|
||
// Type-checks `object.field`.
// The object must be a struct whose definition is visible from `scope` and
// which declares `field`. On success the node takes the field's type (plus its
// struct name, when that is recorded) and field_access.field_index is set.
// A TYPE_ERROR object propagates without a further diagnostic.
void analyze_field_access(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* target = node->as.field_access.object;
    analyze_expr(target, scope, errors, a);

    const TypeKind target_kind = target->type.kind;
    if (target_kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (target_kind != TYPE_STRUCT) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "类型 '%s' 不是结构体,不能访问字段 '%s'",
                  type_name(target_kind), node->as.field_access.field);
        node->type.kind = TYPE_ERROR;
        return;
    }

    const char* sname = target->type.struct_name;
    if (sname == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col, "无法确定结构体类型");
        node->type.kind = TYPE_ERROR;
        return;
    }

    Symbol* def = scope_lookup_struct(scope, sname);
    if (def == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的结构体 '%s'", sname);
        node->type.kind = TYPE_ERROR;
        return;
    }

    const int slot = scope_struct_field_index(def, node->as.field_access.field);
    if (slot < 0) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 没有字段 '%s'", sname, node->as.field_access.field);
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->as.field_access.field_index = slot;
    node->type.kind = def->struct_field_types[slot];
    if (node->type.kind == TYPE_STRUCT &&
        def->struct_field_struct_names && def->struct_field_struct_names[slot]) {
        node->type.struct_name = def->struct_field_struct_names[slot];
    }
}
|
||
|
||
// Type-checks a struct initializer.
//
// type_name is resolved directly, or through a type alias; in the alias case
// type_name is rewritten to the resolved struct name.
// The initializer must supply exactly as many fields as the struct declares.
// Each field must:
//  - name an existing field;
//  - not repeat a field already initialized (otherwise a duplicate such as
//    {x: 1, x: 2} would pass the count check while another field is missing);
//  - have exactly the field's declared TypeKind. A type mismatch is reported
//    but does not make the node TYPE_ERROR.
// On success the node is TYPE_STRUCT with the resolved name; otherwise it is
// TYPE_ERROR.
void analyze_struct_init(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    const char* resolved = node->as.struct_init.type_name;
    Symbol* struct_sym = scope_lookup_struct(scope, resolved);
    if (!struct_sym) {
        Symbol* alias_sym = scope_lookup(scope, resolved);
        if (alias_sym && alias_sym->is_type_alias && alias_sym->struct_type_name) {
            resolved = alias_sym->struct_type_name;
            struct_sym = scope_lookup_struct(scope, resolved);
            node->as.struct_init.type_name = resolved;
        }
    }
    if (!struct_sym) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的结构体类型 '%s'", node->as.struct_init.type_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    if (node->as.struct_init.field_count != struct_sym->struct_field_count) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 有 %zu 个字段,但提供了 %zu 个",
                  node->as.struct_init.type_name,
                  struct_sym->struct_field_count, node->as.struct_init.field_count);
        node->type.kind = TYPE_ERROR;
        return;
    }

    bool ok = true;
    for (size_t i = 0; i < node->as.struct_init.field_count; i++) {
        const char* fname = node->as.struct_init.field_names[i];
        AstNode* fval = node->as.struct_init.field_values[i];
        analyze_expr(fval, scope, errors, a);

        int fi = scope_struct_field_index(struct_sym, fname);
        if (fi < 0) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "结构体 '%s' 没有字段 '%s'", node->as.struct_init.type_name, fname);
            ok = false;
            continue;
        }

        bool duplicate = false;
        for (size_t j = 0; j < i && !duplicate; j++)
            duplicate = strcmp(node->as.struct_init.field_names[j], fname) == 0;
        if (duplicate) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "结构体 '%s' 的字段 '%s' 被重复初始化",
                      node->as.struct_init.type_name, fname);
            ok = false;
            continue;
        }

        TypeKind expected = struct_sym->struct_field_types[fi];
        TypeKind actual = fval->type.kind;
        if (actual != TYPE_ERROR && actual != expected)
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'",
                      fname, type_name(expected), type_name(actual));
    }

    if (ok) {
        node->type.kind = TYPE_STRUCT;
        node->type.struct_name = resolved;
    } else {
        node->type.kind = TYPE_ERROR;
    }
}
|
||
|
||
// Type-checks an enum variant expression, optionally carrying a payload.
//
// Resolves the enum and the variant and records variant_index. The payload
// type declared for the variant comes from variant_payload_types. TYPE_VOID
// means "no payload", which is the case for every variant when the enum
// records no payload types at all.
// Payload handling:
//  - A payload on a no-payload variant is rejected with "不接受 payload".
//    Previously, enums without payload types reported a confusing
//    "expected 'void'" mismatch instead.
//  - Otherwise the payload must have exactly the declared kind.
// The result type is TYPE_ENUM.
// NOTE(review): a missing payload for a variant that declares one is not
// diagnosed here. Confirm whether that is intentional (e.g. patterns).
void analyze_enum_variant(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    const char* enum_name = node->as.enum_variant.enum_name;
    const char* variant_name = node->as.enum_variant.variant_name;

    Symbol* enum_sym = scope_lookup_struct(scope, enum_name);
    if (!enum_sym || enum_sym->kind != SYM_ENUM) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的枚举 '%s'", enum_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    int vi = scope_enum_variant_index(enum_sym, variant_name);
    if (vi < 0) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "枚举 '%s' 没有变体 '%s'", enum_name, variant_name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    node->as.enum_variant.variant_index = vi;

    // ADT payload check.
    TypeKind expected_pt = enum_sym->variant_payload_types
                               ? enum_sym->variant_payload_types[vi]
                               : TYPE_VOID;
    AstNode* payload = node->as.enum_variant.payload;
    if (payload) {
        if (expected_pt == TYPE_VOID) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "枚举变体 '%s::%s' 不接受 payload", enum_name, variant_name);
            node->type.kind = TYPE_ERROR;
            return;
        }
        analyze_expr(payload, scope, errors, a);
        TypeKind actual = payload->type.kind;
        if (actual != TYPE_ERROR && actual != expected_pt) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "枚举变体 payload 类型不匹配: 期望 '%s',得到 '%s'",
                      type_name(expected_pt), type_name(actual));
        }
    }

    node->type.kind = TYPE_ENUM;
}
|
||
|
||
// Type-checks `array[index]`.
// The base must be TYPE_ARRAY and the index exactly i64. The result takes the
// array's element type and element struct name. TYPE_ERROR operands propagate
// without a further diagnostic.
void analyze_index_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* base = node->as.index_expr.array;
    AstNode* subscript = node->as.index_expr.index;
    analyze_expr(base, scope, errors, a);
    analyze_expr(subscript, scope, errors, a);

    const TypeKind base_kind = base->type.kind;
    if (base_kind != TYPE_ARRAY) {
        if (base_kind != TYPE_ERROR) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "类型 '%s' 不支持索引操作", type_name(base_kind));
        }
        node->type.kind = TYPE_ERROR;
        return;
    }

    const TypeKind index_kind = subscript->type.kind;
    if (index_kind != TYPE_I64) {
        if (index_kind != TYPE_ERROR) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "数组索引必须是 i64 类型, 得到 '%s'", type_name(index_kind));
        }
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.kind = base->type.element_type;
    node->type.struct_name = base->type.element_struct_name;
}
|
||
|
||
void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||
analyze_expr(node->as.method_call.receiver, scope, errors, a);
|
||
const char* recv_struct = node->as.method_call.receiver->type.struct_name;
|
||
if (node->as.method_call.receiver->type.kind != TYPE_STRUCT || !recv_struct) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"只有结构体类型支持方法调用");
|
||
node->type.kind = TYPE_ERROR; return;
|
||
}
|
||
char mangled[256];
|
||
snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct,
|
||
node->as.method_call.method_name);
|
||
Symbol* sym = scope_lookup(scope, mangled);
|
||
// trait 方法 fallback: 搜索所有作用域中以 $method_name 结尾且以 StructName 开头的符号
|
||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||
char suffix[256];
|
||
snprintf(suffix, sizeof(suffix), "$%s", node->as.method_call.method_name);
|
||
size_t suf_len = strlen(suffix);
|
||
size_t recv_len = strlen(recv_struct);
|
||
for (const Scope* sc = scope; sc; sc = sc->parent) {
|
||
for (Symbol* s = sc->head; s; s = s->next) {
|
||
if (s->kind == SYM_FUNCTION) {
|
||
size_t name_len = strlen(s->name);
|
||
if (name_len > suf_len + recv_len
|
||
&& strncmp(s->name, recv_struct, recv_len) == 0
|
||
&& strcmp(s->name + name_len - suf_len, suffix) == 0) {
|
||
sym = s;
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
if (sym) break;
|
||
}
|
||
}
|
||
// 更新 method_name 为符号的实际名称(codegen 需要通过它找到 LLVM 函数)
|
||
if (sym && sym->kind == SYM_FUNCTION) {
|
||
node->as.method_call.method_name = sym->name;
|
||
}
|
||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"结构体 '%s' 没有方法 '%s'", recv_struct,
|
||
node->as.method_call.method_name);
|
||
node->type.kind = TYPE_ERROR; return;
|
||
}
|
||
if (node->as.method_call.arg_count + 1 != sym->param_count) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"方法 '%s' 需要 %zu 个参数,提供了 %zu 个",
|
||
node->as.method_call.method_name,
|
||
sym->param_count > 0 ? sym->param_count - 1 : 0,
|
||
node->as.method_call.arg_count);
|
||
node->type.kind = TYPE_ERROR; return;
|
||
}
|
||
if (!reorder_named_args(node, sym, 1, errors, node->as.method_call.method_name)) {
|
||
node->type.kind = TYPE_ERROR; return;
|
||
}
|
||
for (size_t i = 0; i < node->as.method_call.arg_count; i++) {
|
||
analyze_expr(node->as.method_call.args[i], scope, errors, a);
|
||
check_arg_type(node->as.method_call.args[i], sym->param_types[i + 1],
|
||
sym->param_struct_names ? sym->param_struct_names[i + 1] : NULL,
|
||
i, node, sym, errors, a);
|
||
}
|
||
node->type.kind = sym->return_type;
|
||
if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name)
|
||
node->type.struct_name = sym->return_struct_type_name;
|
||
}
|
||
|
||
// === Expression type checking (visitor dispatcher) ===
// SEMA_HANDLER(fn) defines `static void* fn_wrap(void* vctx, AstNode* node)`.
// It adapts an analyze_* function to the callback shape registered with
// ast_dispatch_set: it treats vctx as the SemaCtx installed by analyze_expr,
// forwards its scope/errors/arena to fn, and always returns NULL.
#define SEMA_HANDLER(name)                                              \
    static void* name##_wrap(void* vctx, AstNode* node) {               \
        SemaCtx* sema_ctx = (SemaCtx*)vctx;                             \
        name(node, sema_ctx->scope, sema_ctx->errors, sema_ctx->a);     \
        return NULL;                                                    \
    }
|
||
|
||
// === lambda 表达式分析 ===
|
||
static int lambda_counter = 0;
|
||
|
||
void analyze_lambda(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||
lambda_counter++;
|
||
int name_len = snprintf(NULL, 0, "__lambda_%d", lambda_counter) + 1;
|
||
char* gen_name = arena_alloc_impl(a, name_len);
|
||
snprintf(gen_name, name_len, "__lambda_%d", lambda_counter);
|
||
node->as.lambda.generated_name = gen_name;
|
||
|
||
// 分析 lambda 体(参数作用域)
|
||
Scope* lambda_scope = scope_new(a, scope);
|
||
for (size_t i = 0; i < node->as.lambda.param_count; i++) {
|
||
AstNode* p = node->as.lambda.params[i];
|
||
scope_insert(lambda_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type);
|
||
}
|
||
TypeKind saved_ret = current_return_type;
|
||
const char* saved_ret_sn = current_return_struct_name;
|
||
current_return_type = node->as.lambda.return_type;
|
||
current_return_struct_name = node->as.lambda.return_struct_type_name;
|
||
analyze_node(node->as.lambda.body, lambda_scope, errors, a);
|
||
current_return_type = saved_ret;
|
||
current_return_struct_name = saved_ret_sn;
|
||
|
||
// 创建顶层函数 AST 节点, 加入队列供 codegen 使用
|
||
AstNode* fn = ast_make_function(a, gen_name,
|
||
node->as.lambda.params, node->as.lambda.param_count,
|
||
node->as.lambda.return_type,
|
||
node->as.lambda.return_struct_type_name,
|
||
node->as.lambda.body, false, NULL, 0, node->loc);
|
||
if (lambda_count < 256)
|
||
lambda_queue[lambda_count++] = fn;
|
||
|
||
// 注册函数符号(支持递归调用自身)
|
||
TypeKind* pts = node->as.lambda.param_count > 0
|
||
? arena_alloc_impl(a, node->as.lambda.param_count * sizeof(TypeKind)) : NULL;
|
||
for (size_t i = 0; i < node->as.lambda.param_count; i++)
|
||
pts[i] = node->as.lambda.params[i]->as.parameter.type;
|
||
scope_insert_function(scope, a, gen_name,
|
||
node->as.lambda.return_type,
|
||
node->as.lambda.return_struct_type_name,
|
||
pts, NULL, NULL, node->as.lambda.param_count, NULL, 0);
|
||
|
||
node->type.kind = TYPE_CLOSURE;
|
||
}
|
||
|
||
SEMA_HANDLER(analyze_lambda)
|
||
SEMA_HANDLER(analyze_ident_expr)
|
||
SEMA_HANDLER(analyze_unary_expr)
|
||
SEMA_HANDLER(analyze_binary_expr)
|
||
SEMA_HANDLER(analyze_call_expr)
|
||
SEMA_HANDLER(analyze_field_access)
|
||
SEMA_HANDLER(analyze_struct_init)
|
||
SEMA_HANDLER(analyze_enum_variant)
|
||
SEMA_HANDLER(analyze_index_expr)
|
||
SEMA_HANDLER(analyze_method_call)
|
||
SEMA_HANDLER(analyze_node) // if-expr / block 委托
|
||
|
||
static AstDispatch sema_dispatch;
|
||
|
||
static void analyze_list_comp(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||
analyze_expr(node->as.list_comp.array, scope, errors, a);
|
||
TypeInfo* arr_ti = &node->as.list_comp.array->type;
|
||
if (arr_ti->kind != TYPE_ARRAY) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"列表推导式需要数组类型, 得到 '%s'", type_name(arr_ti->kind));
|
||
node->type.kind = TYPE_ERROR; return;
|
||
}
|
||
Scope* lc_scope = scope_new(a, scope);
|
||
TypeKind elem_k = arr_ti->element_type;
|
||
Symbol* var_sym = scope_insert(lc_scope, a, node->as.list_comp.var_name,
|
||
SYM_VARIABLE, elem_k);
|
||
if (var_sym) var_sym->struct_type_name = arr_ti->element_struct_name;
|
||
analyze_expr(node->as.list_comp.map_expr, lc_scope, errors, a);
|
||
node->type.kind = TYPE_ARRAY;
|
||
node->type.element_type = arr_ti->element_type;
|
||
node->type.element_struct_name = arr_ti->element_struct_name;
|
||
node->type.array_size = arr_ti->array_size;
|
||
}
|
||
SEMA_HANDLER(analyze_list_comp)
|
||
|
||
void analyze_expr_init(void) {
|
||
sema_dispatch.ctx = NULL; // 由 analyze_expr 每次设置
|
||
// 新增表达式节点: 在此注册 handler, 编译器会警告缺失
|
||
ast_dispatch_set(&sema_dispatch, AST_IDENT_EXPR, analyze_ident_expr_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_UNARY_EXPR, analyze_unary_expr_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_BINARY_EXPR, analyze_binary_expr_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_CALL_EXPR, analyze_call_expr_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_FIELD_ACCESS, analyze_field_access_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_STRUCT_INIT, analyze_struct_init_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_ENUM_VARIANT, analyze_enum_variant_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_INDEX_EXPR, analyze_index_expr_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_METHOD_CALL, analyze_method_call_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_IF_STMT, analyze_node_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_BLOCK, analyze_node_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_LIST_COMP, analyze_list_comp_wrap);
|
||
ast_dispatch_set(&sema_dispatch, AST_LAMBDA, analyze_lambda_wrap);
|
||
}
|
||
|
||
// Type-checks expression `node` in `scope` by dispatching on its kind through
// sema_dispatch. What happens for kinds without a registered handler is up to
// ast_visit (not in this file).
//
// Analyzers call analyze_expr recursively for sub-expressions, so the
// dispatcher's context pointer is saved before the visit and restored after
// it. This has two effects:
//  - the enclosing call's context is current again once a nested call returns;
//  - sema_dispatch never keeps a pointer to this frame's `sctx` after it returns.
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    SemaCtx sctx = {scope, errors, a};
    void* prev_ctx = sema_dispatch.ctx;
    sema_dispatch.ctx = &sctx;
    ast_visit(&sema_dispatch, node);
    sema_dispatch.ctx = prev_ctx;
}
|