819 lines
38 KiB
C
819 lines
38 KiB
C
#include "sema.h"
|
||
#include <stdio.h>
|
||
#include <string.h>
|
||
|
||
// === 类型关系 ===
|
||
static TypeKind current_return_type = TYPE_VOID;
|
||
static const char* current_return_struct_name = NULL;
|
||
|
||
/*
 * Result kind of an arithmetic operation whose operands have kinds `a` and `b`.
 * Enum operands count as i64. The first kind in the precedence list
 * f64 > i64 > bool that either operand has is the result; if neither operand
 * has any of them, the result is TYPE_ERROR.
 */
static TypeKind promote(TypeKind a, TypeKind b) {
    static const TypeKind precedence[] = { TYPE_F64, TYPE_I64, TYPE_BOOL };
    const TypeKind lhs = (a == TYPE_ENUM) ? TYPE_I64 : a;
    const TypeKind rhs = (b == TYPE_ENUM) ? TYPE_I64 : b;

    for (size_t k = 0; k < sizeof precedence / sizeof precedence[0]; k++) {
        if (lhs == precedence[k] || rhs == precedence[k]) {
            return precedence[k];
        }
    }
    return TYPE_ERROR;
}
|
||
|
||
/* Whether kind `t` may take part in arithmetic: i64, f64, or an enum. */
static bool is_numeric(TypeKind t) {
    switch (t) {
    case TYPE_I64:
    case TYPE_F64:
    case TYPE_ENUM:
        return true;
    default:
        return false;
    }
}
|
||
/*
 * Whether operands of kinds `a` and `b` may appear on the two sides of a
 * comparison operator (==, !=, <, >, <=, >=).
 *
 * Identical kinds are comparable, except void: an expression of kind void
 * (e.g. a call to a void function) has no value to compare. An enum may also
 * be compared with an i64, since enums take part in integer comparisons.
 */
static bool is_comparable(TypeKind a, TypeKind b) {
    if (a == TYPE_VOID || b == TYPE_VOID) return false;
    if (a == b) return true;
    return (a == TYPE_I64 && b == TYPE_ENUM) || (a == TYPE_ENUM && b == TYPE_I64);
}
|
||
|
||
// === 向前声明 ===
|
||
static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
|
||
|
||
// === 检查表达式 ===
|
||
static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||
switch (node->kind) {
|
||
case AST_LITERAL_EXPR:
|
||
break; // 类型已在创建时设置
|
||
|
||
case AST_IDENT_EXPR: {
|
||
Symbol* sym = scope_lookup(scope, node->as.ident.name);
|
||
if (!sym) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"未定义的变量 '%s'", node->as.ident.name);
|
||
node->type.kind = TYPE_ERROR;
|
||
} else if (sym->is_type_alias) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"'%s' 是类型别名,不能作为表达式使用", node->as.ident.name);
|
||
node->type.kind = TYPE_ERROR;
|
||
} else if (sym->kind == SYM_FUNCTION) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"'%s' 是函数,不能作为表达式使用", node->as.ident.name);
|
||
node->type.kind = TYPE_ERROR;
|
||
} else {
|
||
node->type.kind = sym->type;
|
||
if (sym->type == TYPE_STRUCT && sym->struct_type_name) {
|
||
node->type.struct_name = sym->struct_type_name;
|
||
}
|
||
if (sym->type == TYPE_ARRAY) {
|
||
node->type.element_type = sym->array_element_type;
|
||
node->type.element_struct_name = sym->array_element_struct_name;
|
||
node->type.array_size = sym->array_size;
|
||
}
|
||
}
|
||
break;
|
||
}
|
||
|
||
case AST_UNARY_EXPR: {
|
||
analyze_expr(node->as.unary.operand, scope, errors, a);
|
||
TypeKind inner = node->as.unary.operand->type.kind;
|
||
if (node->as.unary.op == OP_NEG) {
|
||
if (!is_numeric(inner)) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"一元 '-' 只能用于数值类型");
|
||
node->type.kind = TYPE_ERROR;
|
||
} else {
|
||
node->type.kind = inner;
|
||
}
|
||
} else { // OP_NOT
|
||
if (inner != TYPE_BOOL) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"'!' 只能用于布尔类型,得到 '%s'", type_name(inner));
|
||
node->type.kind = TYPE_ERROR;
|
||
} else {
|
||
node->type.kind = TYPE_BOOL;
|
||
}
|
||
}
|
||
break;
|
||
}
|
||
|
||
case AST_BINARY_EXPR: {
|
||
analyze_expr(node->as.binary.left, scope, errors, a);
|
||
analyze_expr(node->as.binary.right, scope, errors, a);
|
||
TypeKind l = node->as.binary.left->type.kind;
|
||
TypeKind r = node->as.binary.right->type.kind;
|
||
if (l == TYPE_ERROR || r == TYPE_ERROR) { node->type.kind = TYPE_ERROR; break; }
|
||
|
||
switch (node->as.binary.op) {
|
||
case OP_ADD:
|
||
if (l == TYPE_STR || r == TYPE_STR) {
|
||
// 字符串拼接:两边都必须是 str 类型
|
||
if (l != TYPE_STR || r != TYPE_STR) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'",
|
||
type_name(l), type_name(r));
|
||
node->type.kind = TYPE_ERROR;
|
||
} else {
|
||
node->type.kind = TYPE_STR;
|
||
}
|
||
} else if (!is_numeric(l) || !is_numeric(r)) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"算术运算需要数值类型");
|
||
node->type.kind = TYPE_ERROR;
|
||
} else {
|
||
node->type.kind = promote(l, r);
|
||
}
|
||
break;
|
||
case OP_SUB: case OP_MUL: case OP_DIV: case OP_MOD:
|
||
if (!is_numeric(l) || !is_numeric(r)) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"算术运算需要数值类型");
|
||
node->type.kind = TYPE_ERROR;
|
||
} else {
|
||
node->type.kind = promote(l, r);
|
||
}
|
||
break;
|
||
case OP_EQ: case OP_NE: case OP_LT: case OP_GT: case OP_LE: case OP_GE:
|
||
if (!is_comparable(l, r)) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"类型 '%s' 和 '%s' 无法比较", type_name(l), type_name(r));
|
||
node->type.kind = TYPE_ERROR;
|
||
} else {
|
||
node->type.kind = TYPE_BOOL;
|
||
}
|
||
break;
|
||
case OP_AND: case OP_OR:
|
||
if (l != TYPE_BOOL || r != TYPE_BOOL) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"逻辑运算需要布尔类型");
|
||
node->type.kind = TYPE_ERROR;
|
||
} else {
|
||
node->type.kind = TYPE_BOOL;
|
||
}
|
||
break;
|
||
default: break;
|
||
}
|
||
break;
|
||
}
|
||
|
||
case AST_CALL_EXPR: {
|
||
Symbol* sym = scope_lookup(scope, node->as.call.name);
|
||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"未定义的函数 '%s'", node->as.call.name);
|
||
node->type.kind = TYPE_ERROR;
|
||
// 即使函数未定义,也要分析参数表达式(它们可能有更多错误)
|
||
for (size_t i = 0; i < node->as.call.arg_count; i++) {
|
||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||
}
|
||
break;
|
||
}
|
||
if (node->as.call.arg_count != sym->param_count) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"函数 '%s' 需要 %zu 个参数,但提供了 %zu 个",
|
||
node->as.call.name, sym->param_count, node->as.call.arg_count);
|
||
node->type.kind = TYPE_ERROR;
|
||
// 即使参数数量不匹配,也分析已有的参数
|
||
for (size_t i = 0; i < node->as.call.arg_count; i++) {
|
||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||
}
|
||
break;
|
||
}
|
||
for (size_t i = 0; i < node->as.call.arg_count; i++) {
|
||
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||
TypeKind actual = node->as.call.args[i]->type.kind;
|
||
TypeKind expected = sym->param_types[i];
|
||
if (actual != TYPE_ERROR) {
|
||
if (expected == TYPE_STRUCT) {
|
||
// 结构体参数:比较具体类型名
|
||
const char* actual_name = node->as.call.args[i]->type.struct_name;
|
||
const char* expected_name = sym->param_struct_names ? sym->param_struct_names[i] : NULL;
|
||
if (actual != TYPE_STRUCT || !actual_name || !expected_name ||
|
||
strcmp(actual_name, expected_name) != 0) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||
i + 1,
|
||
expected_name ? expected_name : "struct",
|
||
actual_name ? actual_name : type_name(actual));
|
||
}
|
||
} else if (actual != expected &&
|
||
!(expected == TYPE_I64 && actual == TYPE_ENUM)) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||
i + 1, type_name(expected), type_name(actual));
|
||
}
|
||
}
|
||
}
|
||
node->type.kind = sym->return_type;
|
||
if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) {
|
||
node->type.struct_name = sym->return_struct_type_name;
|
||
}
|
||
break;
|
||
}
|
||
|
||
case AST_FIELD_ACCESS: {
|
||
analyze_expr(node->as.field_access.object, scope, errors, a);
|
||
AstNode* obj = node->as.field_access.object;
|
||
if (obj->type.kind == TYPE_ERROR) {
|
||
node->type.kind = TYPE_ERROR;
|
||
break;
|
||
}
|
||
if (obj->type.kind != TYPE_STRUCT) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"类型 '%s' 不是结构体,不能访问字段 '%s'",
|
||
type_name(obj->type.kind), node->as.field_access.field);
|
||
node->type.kind = TYPE_ERROR;
|
||
break;
|
||
}
|
||
// 查找结构体定义
|
||
const char* struct_name = obj->type.struct_name;
|
||
if (!struct_name) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"无法确定结构体类型");
|
||
node->type.kind = TYPE_ERROR;
|
||
break;
|
||
}
|
||
Symbol* struct_sym = scope_lookup_struct(scope, struct_name);
|
||
if (!struct_sym) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"未定义的结构体 '%s'", struct_name);
|
||
node->type.kind = TYPE_ERROR;
|
||
break;
|
||
}
|
||
int fi = scope_struct_field_index(struct_sym, node->as.field_access.field);
|
||
if (fi < 0) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"结构体 '%s' 没有字段 '%s'", struct_name, node->as.field_access.field);
|
||
node->type.kind = TYPE_ERROR;
|
||
break;
|
||
}
|
||
node->type.kind = struct_sym->struct_field_types[fi];
|
||
node->as.field_access.field_index = fi;
|
||
// 如果字段也是结构体类型,传播类型名
|
||
if (node->type.kind == TYPE_STRUCT &&
|
||
struct_sym->struct_field_struct_names &&
|
||
struct_sym->struct_field_struct_names[fi]) {
|
||
node->type.struct_name = struct_sym->struct_field_struct_names[fi];
|
||
}
|
||
break;
|
||
}
|
||
|
||
case AST_STRUCT_INIT: {
|
||
const char* resolved_type_name = node->as.struct_init.type_name;
|
||
Symbol* struct_sym = scope_lookup_struct(scope, resolved_type_name);
|
||
if (!struct_sym) {
|
||
// 检查是否是类型别名指向结构体
|
||
Symbol* alias_sym = scope_lookup(scope, resolved_type_name);
|
||
if (alias_sym && alias_sym->is_type_alias && alias_sym->struct_type_name) {
|
||
resolved_type_name = alias_sym->struct_type_name;
|
||
struct_sym = scope_lookup_struct(scope, resolved_type_name);
|
||
// 更新 type_name 为真实结构体名(codegen 需要)
|
||
node->as.struct_init.type_name = resolved_type_name;
|
||
}
|
||
}
|
||
if (!struct_sym) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"未定义的结构体类型 '%s'", node->as.struct_init.type_name);
|
||
node->type.kind = TYPE_ERROR;
|
||
break;
|
||
}
|
||
if (node->as.struct_init.field_count != struct_sym->struct_field_count) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"结构体 '%s' 有 %zu 个字段,但提供了 %zu 个",
|
||
node->as.struct_init.type_name,
|
||
struct_sym->struct_field_count,
|
||
node->as.struct_init.field_count);
|
||
node->type.kind = TYPE_ERROR;
|
||
break;
|
||
}
|
||
// 检查每个字段名和类型匹配
|
||
for (size_t i = 0; i < node->as.struct_init.field_count; i++) {
|
||
const char* fname = node->as.struct_init.field_names[i];
|
||
AstNode* fval = node->as.struct_init.field_values[i];
|
||
analyze_expr(fval, scope, errors, a);
|
||
|
||
int fi = scope_struct_field_index(struct_sym, fname);
|
||
if (fi < 0) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"结构体 '%s' 没有字段 '%s'",
|
||
node->as.struct_init.type_name, fname);
|
||
node->type.kind = TYPE_ERROR;
|
||
continue;
|
||
}
|
||
TypeKind expected = struct_sym->struct_field_types[fi];
|
||
TypeKind actual = fval->type.kind;
|
||
if (actual != TYPE_ERROR && actual != expected) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'",
|
||
fname, type_name(expected), type_name(actual));
|
||
}
|
||
}
|
||
if (node->type.kind != TYPE_ERROR) {
|
||
node->type.kind = TYPE_STRUCT;
|
||
node->type.struct_name = resolved_type_name;
|
||
}
|
||
break;
|
||
}
|
||
|
||
case AST_ENUM_VARIANT: {
|
||
Symbol* enum_sym = scope_lookup_struct(scope, node->as.enum_variant.enum_name);
|
||
if (!enum_sym || enum_sym->kind != SYM_ENUM) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"未定义的枚举 '%s'", node->as.enum_variant.enum_name);
|
||
node->type.kind = TYPE_ERROR; break;
|
||
}
|
||
int vi = scope_enum_variant_index(enum_sym, node->as.enum_variant.variant_name);
|
||
if (vi < 0) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"枚举 '%s' 没有变体 '%s'",
|
||
node->as.enum_variant.enum_name,
|
||
node->as.enum_variant.variant_name);
|
||
node->type.kind = TYPE_ERROR; break;
|
||
}
|
||
node->as.enum_variant.variant_index = vi;
|
||
node->type.kind = TYPE_ENUM;
|
||
break;
|
||
}
|
||
|
||
case AST_INDEX_EXPR: {
|
||
analyze_expr(node->as.index_expr.array, scope, errors, a);
|
||
analyze_expr(node->as.index_expr.index, scope, errors, a);
|
||
AstNode* arr = node->as.index_expr.array;
|
||
AstNode* idx = node->as.index_expr.index;
|
||
|
||
if (arr->type.kind == TYPE_ERROR) { node->type.kind = TYPE_ERROR; break; }
|
||
if (arr->type.kind != TYPE_ARRAY) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"类型 '%s' 不支持索引操作", type_name(arr->type.kind));
|
||
node->type.kind = TYPE_ERROR; break;
|
||
}
|
||
if (idx->type.kind == TYPE_ERROR) { node->type.kind = TYPE_ERROR; break; }
|
||
if (idx->type.kind != TYPE_I64) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"数组索引必须是 i64 类型, 得到 '%s'", type_name(idx->type.kind));
|
||
node->type.kind = TYPE_ERROR; break;
|
||
}
|
||
// 结果类型 = 元素类型
|
||
node->type.kind = arr->type.element_type;
|
||
node->type.struct_name = arr->type.element_struct_name;
|
||
break;
|
||
}
|
||
|
||
case AST_METHOD_CALL: {
|
||
analyze_expr(node->as.method_call.receiver, scope, errors, a);
|
||
const char* recv_struct = node->as.method_call.receiver->type.struct_name;
|
||
if (node->as.method_call.receiver->type.kind != TYPE_STRUCT || !recv_struct) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"只有结构体类型支持方法调用");
|
||
node->type.kind = TYPE_ERROR; break;
|
||
}
|
||
// 构造改名后的函数名并查找
|
||
char mangled[256];
|
||
snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct,
|
||
node->as.method_call.method_name);
|
||
Symbol* sym = scope_lookup(scope, mangled);
|
||
if (!sym || sym->kind != SYM_FUNCTION) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"结构体 '%s' 没有方法 '%s'", recv_struct,
|
||
node->as.method_call.method_name);
|
||
node->type.kind = TYPE_ERROR; break;
|
||
}
|
||
// 检查参数数量(用户提供的参数 + 隐含的 self)
|
||
if (node->as.method_call.arg_count + 1 != sym->param_count) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"方法 '%s' 需要 %zu 个参数,提供了 %zu 个",
|
||
node->as.method_call.method_name,
|
||
sym->param_count > 0 ? sym->param_count - 1 : 0,
|
||
node->as.method_call.arg_count);
|
||
node->type.kind = TYPE_ERROR; break;
|
||
}
|
||
// 对每个参数进行类型检查(跳过 self 参数,即 sym->param_types[0] 是 self 的类型)
|
||
for (size_t i = 0; i < node->as.method_call.arg_count; i++) {
|
||
analyze_expr(node->as.method_call.args[i], scope, errors, a);
|
||
TypeKind actual = node->as.method_call.args[i]->type.kind;
|
||
TypeKind expected = sym->param_types[i + 1];
|
||
if (actual != TYPE_ERROR && actual != expected &&
|
||
!(expected == TYPE_I64 && actual == TYPE_ENUM)) {
|
||
if (expected == TYPE_STRUCT) {
|
||
// 结构体类型参数:比较具体类型名
|
||
const char* actual_name = node->as.method_call.args[i]->type.struct_name;
|
||
const char* expected_name = sym->param_struct_names ? sym->param_struct_names[i + 1] : NULL;
|
||
if (!actual_name || !expected_name || strcmp(actual_name, expected_name) != 0) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||
i + 1,
|
||
expected_name ? expected_name : "struct",
|
||
actual_name ? actual_name : type_name(actual));
|
||
}
|
||
} else {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
|
||
i + 1, type_name(expected), type_name(actual));
|
||
}
|
||
}
|
||
}
|
||
node->type.kind = sym->return_type;
|
||
if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) {
|
||
node->type.struct_name = sym->return_struct_type_name;
|
||
}
|
||
break;
|
||
}
|
||
|
||
default: break;
|
||
}
|
||
}
|
||
|
||
static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||
if (!node) return;
|
||
|
||
switch (node->kind) {
|
||
case AST_PROGRAM:
|
||
// Pass 0: 注册类型别名
|
||
for (size_t i = 0; i < node->as.program.alias_count; i++) {
|
||
AstNode* alias = node->as.program.type_aliases[i];
|
||
Symbol* sym = scope_insert(scope, a, alias->as.type_alias.name,
|
||
SYM_VARIABLE, alias->as.type_alias.aliased_type);
|
||
if (sym) {
|
||
sym->is_type_alias = true;
|
||
if (alias->as.type_alias.aliased_struct_name) {
|
||
sym->struct_type_name = alias->as.type_alias.aliased_struct_name;
|
||
}
|
||
}
|
||
}
|
||
// 注册枚举
|
||
for (size_t i = 0; i < node->as.program.enum_count; i++) {
|
||
AstNode* ed = node->as.program.enums[i];
|
||
scope_insert_enum(scope, a, ed->as.enum_decl.name,
|
||
ed->as.enum_decl.variants,
|
||
ed->as.enum_decl.variant_count);
|
||
}
|
||
// 第一遍:收集所有结构体定义
|
||
for (size_t i = 0; i < node->as.program.struct_count; i++) {
|
||
AstNode* sd = node->as.program.structs[i];
|
||
const char** fnames = (const char**)arena_alloc_impl(a,
|
||
sd->as.struct_decl.field_count * sizeof(const char*));
|
||
TypeKind* ftypes = (TypeKind*)arena_alloc_impl(a,
|
||
sd->as.struct_decl.field_count * sizeof(TypeKind));
|
||
const char** fstruct_names = (const char**)arena_alloc_impl(a,
|
||
sd->as.struct_decl.field_count * sizeof(const char*));
|
||
for (size_t j = 0; j < sd->as.struct_decl.field_count; j++) {
|
||
fnames[j] = sd->as.struct_decl.fields[j]->as.parameter.name;
|
||
ftypes[j] = sd->as.struct_decl.fields[j]->as.parameter.type;
|
||
fstruct_names[j] = sd->as.struct_decl.fields[j]->as.parameter.struct_type_name;
|
||
}
|
||
scope_insert_struct(scope, a, sd->as.struct_decl.name,
|
||
fnames, ftypes, fstruct_names,
|
||
sd->as.struct_decl.field_count);
|
||
}
|
||
// 处理 impl 块:将方法名改写为 StructName$methodName,
|
||
// 并自动添加 self 参数(第一个参数),然后注册为普通函数。
|
||
// 同时将改写后的方法追加到程序 functions 数组方便后续 codegen。
|
||
{
|
||
// 先统计需要新增多少个函数(impl 中的方法总数)
|
||
size_t extra_fn = 0;
|
||
for (size_t i = 0; i < node->as.program.impl_count; i++) {
|
||
AstNode* impl = node->as.program.impls[i];
|
||
extra_fn += impl->as.impl_block.method_count;
|
||
}
|
||
if (extra_fn > 0) {
|
||
AstNode** new_fns = (AstNode**)arena_alloc_impl(a,
|
||
(node->as.program.fn_count + extra_fn) * sizeof(AstNode*));
|
||
memcpy(new_fns, node->as.program.functions,
|
||
node->as.program.fn_count * sizeof(AstNode*));
|
||
size_t write_pos = node->as.program.fn_count;
|
||
|
||
for (size_t i = 0; i < node->as.program.impl_count; i++) {
|
||
AstNode* impl = node->as.program.impls[i];
|
||
const char* st_name = impl->as.impl_block.struct_name;
|
||
// 验证目标结构体存在
|
||
Symbol* st_sym = scope_lookup_struct(scope, st_name);
|
||
if (!st_sym) {
|
||
error_add(errors, "<sema>", impl->loc.line, impl->loc.col,
|
||
"impl 的目标结构体 '%s' 未定义", st_name);
|
||
continue;
|
||
}
|
||
for (size_t j = 0; j < impl->as.impl_block.method_count; j++) {
|
||
AstNode* method = impl->as.impl_block.methods[j];
|
||
// 构造改名后的函数名
|
||
char mangled[256];
|
||
snprintf(mangled, sizeof(mangled), "%s$%s", st_name,
|
||
method->as.function.name);
|
||
method->as.function.name = arena_strdup_impl(a, mangled,
|
||
strlen(mangled));
|
||
// 追加到新 functions 数组
|
||
new_fns[write_pos++] = method;
|
||
}
|
||
}
|
||
// 更新程序节点
|
||
node->as.program.functions = new_fns;
|
||
node->as.program.fn_count = node->as.program.fn_count + extra_fn;
|
||
}
|
||
}
|
||
|
||
// 第二遍:收集所有函数签名
|
||
for (size_t i = 0; i < node->as.program.fn_count; i++) {
|
||
AstNode* fn = node->as.program.functions[i];
|
||
TypeKind* pts = (TypeKind*)arena_alloc_impl(a, fn->as.function.param_count * sizeof(TypeKind));
|
||
const char** pstruct_names = (const char**)arena_alloc_impl(a, fn->as.function.param_count * sizeof(const char*));
|
||
for (size_t j = 0; j < fn->as.function.param_count; j++) {
|
||
TypeKind pt = fn->as.function.params[j]->as.parameter.type;
|
||
const char* psn = fn->as.function.params[j]->as.parameter.struct_type_name;
|
||
// 解析参数类型的别名
|
||
if (psn) {
|
||
Symbol* as = scope_lookup(scope, psn);
|
||
if (as && as->is_type_alias) {
|
||
pt = as->type;
|
||
psn = as->struct_type_name;
|
||
}
|
||
}
|
||
pts[j] = pt;
|
||
pstruct_names[j] = psn;
|
||
}
|
||
// 解析返回类型的别名
|
||
TypeKind ret_t = fn->as.function.return_type;
|
||
const char* ret_sn = fn->as.function.return_struct_type_name;
|
||
if (ret_sn) {
|
||
Symbol* as = scope_lookup(scope, ret_sn);
|
||
if (as && as->is_type_alias) {
|
||
ret_t = as->type;
|
||
ret_sn = as->struct_type_name;
|
||
}
|
||
}
|
||
scope_insert_function(scope, a, fn->as.function.name,
|
||
ret_t, ret_sn,
|
||
pts, pstruct_names,
|
||
fn->as.function.param_count);
|
||
}
|
||
// 第三遍:分析每个函数体
|
||
for (size_t i = 0; i < node->as.program.fn_count; i++) {
|
||
analyze_node(node->as.program.functions[i], scope, errors, a);
|
||
}
|
||
break;
|
||
|
||
case AST_FUNCTION: {
|
||
Scope* fn_scope = scope_new(a, scope);
|
||
// 注册参数(同时解析类型别名,更新 AST 节点供 codegen 使用)
|
||
for (size_t i = 0; i < node->as.function.param_count; i++) {
|
||
AstNode* p = node->as.function.params[i];
|
||
TypeKind pt = p->as.parameter.type;
|
||
const char* psn = p->as.parameter.struct_type_name;
|
||
if (psn) {
|
||
Symbol* as = scope_lookup(scope, psn);
|
||
if (as && as->is_type_alias) {
|
||
pt = as->type;
|
||
psn = as->struct_type_name;
|
||
// 更新 AST 节点
|
||
p->as.parameter.type = pt;
|
||
p->as.parameter.struct_type_name = psn;
|
||
}
|
||
}
|
||
Symbol* sym = scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, pt);
|
||
if (sym && pt == TYPE_STRUCT && psn) {
|
||
sym->struct_type_name = psn;
|
||
}
|
||
}
|
||
// 解析返回类型的别名
|
||
const char* ret_sn = node->as.function.return_struct_type_name;
|
||
if (ret_sn) {
|
||
Symbol* as = scope_lookup(scope, ret_sn);
|
||
if (as && as->is_type_alias) {
|
||
node->as.function.return_type = as->type;
|
||
node->as.function.return_struct_type_name = as->struct_type_name;
|
||
}
|
||
}
|
||
TypeKind saved = current_return_type;
|
||
const char* saved_name = current_return_struct_name;
|
||
current_return_type = node->as.function.return_type;
|
||
current_return_struct_name = node->as.function.return_struct_type_name;
|
||
analyze_node(node->as.function.body, fn_scope, errors, a);
|
||
current_return_type = saved;
|
||
current_return_struct_name = saved_name;
|
||
break;
|
||
}
|
||
|
||
case AST_BLOCK:
|
||
for (size_t i = 0; i < node->as.block.stmt_count; i++) {
|
||
analyze_node(node->as.block.stmts[i], scope, errors, a);
|
||
}
|
||
break;
|
||
|
||
case AST_LET_STMT: {
|
||
TypeKind var_type;
|
||
const char* var_struct_name = NULL;
|
||
bool is_array_type = false;
|
||
|
||
if (node->as.let_stmt.has_type_annot) {
|
||
if (node->as.let_stmt.annot_type == TYPE_ARRAY) {
|
||
// 数组类型标注: 跳过 init 分析 (init 是自引用的占位符)
|
||
is_array_type = true;
|
||
var_type = TYPE_ARRAY;
|
||
} else {
|
||
analyze_expr(node->as.let_stmt.init, scope, errors, a);
|
||
TypeKind inferred = node->as.let_stmt.init->type.kind;
|
||
const char* annot_struct = node->as.let_stmt.struct_type_name;
|
||
if (annot_struct) {
|
||
// 先检查是否是类型别名
|
||
Symbol* alias_sym = scope_lookup(scope, annot_struct);
|
||
if (alias_sym && alias_sym->is_type_alias) {
|
||
var_type = alias_sym->type;
|
||
var_struct_name = alias_sym->struct_type_name;
|
||
} else {
|
||
// struct 类型标注
|
||
Symbol* st_sym = scope_lookup_struct(scope, annot_struct);
|
||
if (!st_sym) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"未定义的类型 '%s'", annot_struct);
|
||
break;
|
||
}
|
||
var_type = TYPE_STRUCT;
|
||
var_struct_name = annot_struct;
|
||
}
|
||
} else {
|
||
var_type = node->as.let_stmt.annot_type;
|
||
}
|
||
if (inferred != TYPE_ERROR && inferred != var_type) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"变量 '%s' 类型标注为 '%s',但初始化表达式类型为 '%s'",
|
||
node->as.let_stmt.name,
|
||
annot_struct ? annot_struct : type_name(var_type),
|
||
type_name(inferred));
|
||
}
|
||
}
|
||
} else {
|
||
analyze_expr(node->as.let_stmt.init, scope, errors, a);
|
||
TypeKind inferred = node->as.let_stmt.init->type.kind;
|
||
// 类型推断
|
||
if (inferred == TYPE_ERROR || inferred == TYPE_VOID) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"无法从表达式推断变量 '%s' 的类型", node->as.let_stmt.name);
|
||
break;
|
||
}
|
||
var_type = inferred;
|
||
if (inferred == TYPE_STRUCT) {
|
||
var_struct_name = node->as.let_stmt.init->type.struct_name;
|
||
}
|
||
}
|
||
|
||
node->type.kind = var_type;
|
||
node->type.struct_name = var_struct_name;
|
||
if (is_array_type) {
|
||
node->type.element_type = node->as.let_stmt.annot_element_type;
|
||
node->type.element_struct_name = node->as.let_stmt.annot_element_struct_name;
|
||
node->type.array_size = node->as.let_stmt.annot_array_size;
|
||
}
|
||
Symbol* sym = scope_insert(scope, a, node->as.let_stmt.name, SYM_VARIABLE, var_type);
|
||
if (!sym) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"变量 '%s' 重复定义", node->as.let_stmt.name);
|
||
} else {
|
||
sym->is_mut = node->as.let_stmt.is_mut;
|
||
if (var_struct_name) {
|
||
sym->type = TYPE_STRUCT;
|
||
sym->struct_type_name = var_struct_name;
|
||
}
|
||
if (is_array_type) {
|
||
sym->array_element_type = node->type.element_type;
|
||
sym->array_element_struct_name = node->type.element_struct_name;
|
||
sym->array_size = node->type.array_size;
|
||
}
|
||
}
|
||
break;
|
||
}
|
||
|
||
case AST_ASSIGN_STMT: {
|
||
Symbol* sym = scope_lookup(scope, node->as.assign_stmt.name);
|
||
if (!sym) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"未定义的变量 '%s'", node->as.assign_stmt.name);
|
||
node->type.kind = TYPE_ERROR;
|
||
break;
|
||
}
|
||
if (sym->kind != SYM_VARIABLE) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"'%s' 不是变量,不能赋值", node->as.assign_stmt.name);
|
||
node->type.kind = TYPE_ERROR;
|
||
break;
|
||
}
|
||
if (!sym->is_mut) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"不能对不可变变量 '%s' 赋值(需用 var 声明)",
|
||
node->as.assign_stmt.name);
|
||
node->type.kind = TYPE_ERROR;
|
||
break;
|
||
}
|
||
analyze_expr(node->as.assign_stmt.value, scope, errors, a);
|
||
TypeKind value_ty = node->as.assign_stmt.value->type.kind;
|
||
if (value_ty != TYPE_ERROR && value_ty != sym->type) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"赋值类型不匹配: 变量 '%s' 类型为 '%s',但表达式类型为 '%s'",
|
||
node->as.assign_stmt.name, type_name(sym->type), type_name(value_ty));
|
||
}
|
||
node->type.kind = TYPE_VOID;
|
||
break;
|
||
}
|
||
|
||
case AST_ARRAY_ASSIGN_STMT: {
|
||
Symbol* sym = scope_lookup(scope, node->as.array_assign.name);
|
||
if (!sym) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"未定义的变量 '%s'", node->as.array_assign.name);
|
||
node->type.kind = TYPE_ERROR; break;
|
||
}
|
||
if (sym->type != TYPE_ARRAY) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"'%s' 不是数组类型,不能使用索引赋值", node->as.array_assign.name);
|
||
node->type.kind = TYPE_ERROR; break;
|
||
}
|
||
analyze_expr(node->as.array_assign.index, scope, errors, a);
|
||
analyze_expr(node->as.array_assign.value, scope, errors, a);
|
||
AstNode* idx = node->as.array_assign.index;
|
||
AstNode* val = node->as.array_assign.value;
|
||
if (idx->type.kind != TYPE_ERROR && idx->type.kind != TYPE_I64) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"数组索引必须是 i64 类型, 得到 '%s'", type_name(idx->type.kind));
|
||
}
|
||
TypeKind elem_kind = (sym->type == TYPE_ARRAY) ? sym->array_element_type : sym->type;
|
||
if (val->type.kind != TYPE_ERROR && val->type.kind != elem_kind) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"数组元素类型不匹配: 期望 '%s',得到 '%s'",
|
||
type_name(elem_kind), type_name(val->type.kind));
|
||
}
|
||
// struct 元素类型时还需检查结构体名是否匹配
|
||
if (val->type.kind != TYPE_ERROR && elem_kind == TYPE_STRUCT
|
||
&& sym->array_element_struct_name && val->type.struct_name) {
|
||
if (strcmp(sym->array_element_struct_name, val->type.struct_name) != 0) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"数组元素类型不匹配: 期望 '%s',得到 '%s'",
|
||
sym->array_element_struct_name, val->type.struct_name);
|
||
}
|
||
}
|
||
node->type.kind = TYPE_VOID;
|
||
break;
|
||
}
|
||
|
||
case AST_IF_STMT:
|
||
analyze_expr(node->as.if_stmt.cond, scope, errors, a);
|
||
if (node->as.if_stmt.cond->type.kind != TYPE_BOOL &&
|
||
node->as.if_stmt.cond->type.kind != TYPE_ERROR) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col, "if 条件必须是布尔类型");
|
||
}
|
||
analyze_node(node->as.if_stmt.then_block, scope, errors, a);
|
||
if (node->as.if_stmt.else_block) {
|
||
analyze_node(node->as.if_stmt.else_block, scope, errors, a);
|
||
}
|
||
break;
|
||
|
||
case AST_WHILE_STMT:
|
||
analyze_expr(node->as.while_stmt.cond, scope, errors, a);
|
||
if (node->as.while_stmt.cond->type.kind != TYPE_BOOL &&
|
||
node->as.while_stmt.cond->type.kind != TYPE_ERROR) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col, "while 条件必须是布尔类型");
|
||
}
|
||
analyze_node(node->as.while_stmt.body, scope, errors, a);
|
||
break;
|
||
|
||
case AST_RETURN_STMT:
|
||
if (node->as.return_stmt.expr) {
|
||
analyze_expr(node->as.return_stmt.expr, scope, errors, a);
|
||
node->type.kind = node->as.return_stmt.expr->type.kind;
|
||
TypeKind actual = node->as.return_stmt.expr->type.kind;
|
||
TypeKind expected = current_return_type;
|
||
if (actual != TYPE_ERROR && expected != TYPE_VOID) {
|
||
if (expected == TYPE_STRUCT) {
|
||
// 结构体返回类型:比较具体类型名
|
||
const char* actual_name = node->as.return_stmt.expr->type.struct_name;
|
||
const char* expected_name = current_return_struct_name;
|
||
if (actual != TYPE_STRUCT || !actual_name || !expected_name ||
|
||
strcmp(actual_name, expected_name) != 0) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"返回类型不匹配: 期望 '%s',得到 '%s'",
|
||
expected_name ? expected_name : "struct",
|
||
actual_name ? actual_name : type_name(actual));
|
||
}
|
||
} else if (actual != expected &&
|
||
!(expected == TYPE_I64 && actual == TYPE_ENUM)) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"返回类型不匹配: 期望 '%s',得到 '%s'",
|
||
type_name(expected), type_name(actual));
|
||
}
|
||
}
|
||
} else if (current_return_type != TYPE_VOID) {
|
||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||
"函数应返回值类型 '%s'",
|
||
current_return_struct_name ? current_return_struct_name : type_name(current_return_type));
|
||
}
|
||
break;
|
||
|
||
case AST_EXPR_STMT:
|
||
analyze_expr(node->as.expr_stmt.expr, scope, errors, a);
|
||
break;
|
||
|
||
default:
|
||
analyze_expr(node, scope, errors, a);
|
||
break;
|
||
}
|
||
}
|
||
|
||
/*
 * Run semantic analysis over the program `ast`.
 *
 * Creates the global scope, registers the built-in print functions (one
 * parameter each, returning void), then type-checks the program.
 * Diagnostics are appended to `errors`; scopes and symbols are allocated
 * from `arena`.
 */
void sema_analyze(AstNode* ast, ErrorList* errors, Arena* arena) {
    /* Static storage on purpose: each built-in's parameter-type array is
     * handed to scope_insert_function by pointer and may be retained (TODO
     * confirm whether it copies), while the symbols live in `arena`, beyond
     * this stack frame. User-defined signatures likewise use arena-allocated
     * arrays rather than stack ones. */
    static struct {
        const char* name;
        TypeKind param;
    } builtins[] = {
        { "print_i64",  TYPE_I64  },
        { "print_f64",  TYPE_F64  },
        { "print_bool", TYPE_BOOL },
        { "print_str",  TYPE_STR  },
    };

    Scope* global_scope = scope_new(arena, NULL);
    for (size_t i = 0; i < sizeof builtins / sizeof builtins[0]; i++) {
        scope_insert_function(global_scope, arena, builtins[i].name, TYPE_VOID, NULL,
                              &builtins[i].param, NULL, 1);
    }

    analyze_node(ast, global_scope, errors, arena);
}
|