/*
 * l-language/src/sema/sema.c
 *
 * Note: the repository viewer flags this file as containing ambiguous Unicode
 * characters; the source contains Chinese text in comments and in diagnostic
 * message strings.
 */
#include "sema.h"
#include <stdio.h>
#include <string.h>
// === Type relations ===
// Declared return type of the function whose body is currently being analysed.
// analyze_node saves and restores it around every AST_FUNCTION visit and reads
// it when checking AST_RETURN_STMT.
static TypeKind current_return_type = TYPE_VOID;
// Struct type name that accompanies current_return_type (NULL until a function
// with a named return type is entered).
static const char* current_return_struct_name = NULL;
/*
 * Result type of an arithmetic operation on operand types `a` and `b`.
 *
 * Enums are treated as i64 and chars as i32. After that the result is the
 * first entry of {f64, i64, u64, i32, bool} that either operand equals, or
 * TYPE_ERROR when neither operand is one of those.
 */
static TypeKind promote(TypeKind a, TypeKind b) {
    /* Highest-priority type first. */
    static const TypeKind precedence[] = {
        TYPE_F64, TYPE_I64, TYPE_U64, TYPE_I32, TYPE_BOOL
    };
    TypeKind lhs = a;
    TypeKind rhs = b;

    /* Enums take part in arithmetic as i64, chars as i32. */
    if (lhs == TYPE_ENUM) {
        lhs = TYPE_I64;
    } else if (lhs == TYPE_CHAR) {
        lhs = TYPE_I32;
    }
    if (rhs == TYPE_ENUM) {
        rhs = TYPE_I64;
    } else if (rhs == TYPE_CHAR) {
        rhs = TYPE_I32;
    }

    for (size_t k = 0; k < sizeof precedence / sizeof precedence[0]; k++) {
        if (lhs == precedence[k] || rhs == precedence[k]) {
            return precedence[k];
        }
    }
    return TYPE_ERROR;
}
/*
 * True when `t` may appear as an operand of an arithmetic operator:
 * i32, i64, u64, f64, char and enum.
 */
static bool is_numeric(TypeKind t) {
    switch (t) {
    case TYPE_I32:
    case TYPE_I64:
    case TYPE_U64:
    case TYPE_F64:
    case TYPE_CHAR:
    case TYPE_ENUM:
        return true;
    default:
        return false;
    }
}
/*
 * Implicit conversion rules.
 *
 * Returns true when a value of type `from` may be used where `to` is expected
 * without an explicit cast:
 *   - identical types (checked after enums have been normalised to i64, so
 *     enum <-> i64 is accepted);
 *   - char -> i32 / i64 / u64 / f64;
 *   - i32  -> i64 / f64;
 *   - i64  -> f64 / u64;
 *   - u64  -> f64 / i64.
 *
 * i64 and u64 convert in both directions (same bit width). Previously the
 * i64 -> u64 direction was unreachable because an earlier `from == TYPE_I64`
 * test returned first, and the equality test ran only before enum
 * normalisation, so enum <-> i64 was rejected while i32 -> enum was accepted.
 */
static bool can_implicit_convert(TypeKind from, TypeKind to) {
    // Enums are treated as i64 on both sides.
    if (from == TYPE_ENUM) from = TYPE_I64;
    if (to == TYPE_ENUM) to = TYPE_I64;
    if (from == to) return true;

    // char widens to any integer type or to f64.
    if (from == TYPE_CHAR) {
        return to == TYPE_I32 || to == TYPE_I64 || to == TYPE_U64 || to == TYPE_F64;
    }
    // i32 widens to i64 or f64.
    if (from == TYPE_I32) return to == TYPE_I64 || to == TYPE_F64;
    // i64 <-> u64 is allowed in both directions; both convert to f64.
    if (from == TYPE_I64) return to == TYPE_F64 || to == TYPE_U64;
    if (from == TYPE_U64) return to == TYPE_F64 || to == TYPE_I64;
    return false;
}
/*
 * True when values of types `a` and `b` may be compared with each other:
 * the kinds are identical, or one side is an enum and the other is i64.
 */
static bool is_comparable(TypeKind a, TypeKind b) {
    bool same_kind = (a == b);
    /* Enums may be compared against i64 integers, in either operand order. */
    bool enum_with_i64 = (a == TYPE_ENUM && b == TYPE_I64)
                      || (a == TYPE_I64 && b == TYPE_ENUM);
    return same_kind || enum_with_i64;
}
// === Forward declarations ===
// Statement/declaration walker; declared here because analyze_expr dispatches
// back to it for AST_IF_STMT and AST_BLOCK nodes.
static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
// === Expression type-checking helpers ===
// Expression dispatcher; declared here because the per-expression helpers
// below call it recursively on their sub-expressions.
static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a);
/*
 * Resolve an identifier expression and copy the symbol's type onto the node.
 *
 * Reports an error (and types the node TYPE_ERROR) when the name is unknown,
 * names a type alias, or names a function. Otherwise the node receives the
 * symbol's kind plus, for structs, the struct name and, for arrays, the
 * element type, element struct name and size.
 */
static void analyze_ident_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    (void)a;
    Symbol* found = scope_lookup(scope, node->as.ident.name);

    if (found == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的变量 '%s'", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (found->is_type_alias) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'%s' 是类型别名,不能作为表达式使用", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (found->kind == SYM_FUNCTION) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'%s' 是函数,不能作为表达式使用", node->as.ident.name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.kind = found->type;
    switch (found->type) {
    case TYPE_STRUCT:
        if (found->struct_type_name) {
            node->type.struct_name = found->struct_type_name;
        }
        break;
    case TYPE_ARRAY:
        node->type.element_type = found->array_element_type;
        node->type.element_struct_name = found->array_element_struct_name;
        node->type.array_size = found->array_size;
        break;
    default:
        break;
    }
}
/*
 * Type-check a unary expression.
 *
 * OP_NEG requires a numeric operand and yields the operand's type; any other
 * operator is treated as logical NOT, which requires and yields bool.
 *
 * If the operand already failed analysis, the error type is propagated
 * without a second diagnostic, matching analyze_binary_expr.
 */
static void analyze_unary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    analyze_expr(node->as.unary.operand, scope, errors, a);
    TypeKind inner = node->as.unary.operand->type.kind;

    // Do not cascade: the operand's own error has already been reported.
    if (inner == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    if (node->as.unary.op == OP_NEG) {
        if (!is_numeric(inner)) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "一元 '-' 只能用于数值类型");
            node->type.kind = TYPE_ERROR;
        } else {
            node->type.kind = inner;
        }
    } else { // OP_NOT
        if (inner != TYPE_BOOL) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "'!' 只能用于布尔类型,得到 '%s'", type_name(inner));
            node->type.kind = TYPE_ERROR;
        } else {
            node->type.kind = TYPE_BOOL;
        }
    }
}
/*
 * Type-check a binary expression after analysing both operands.
 *
 *   - If either operand has TYPE_ERROR the node becomes TYPE_ERROR silently.
 *   - `+` with a str operand is string concatenation and needs str on both
 *     sides; otherwise `+ - * / %` need numeric operands and yield
 *     promote(l, r).
 *   - Comparisons need is_comparable(l, r) and yield bool.
 *   - `&&` / `||` need bool on both sides and yield bool.
 *   - Any other operator leaves the node's type untouched.
 */
static void analyze_binary_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* lhs = node->as.binary.left;
    AstNode* rhs = node->as.binary.right;
    analyze_expr(lhs, scope, errors, a);
    analyze_expr(rhs, scope, errors, a);

    TypeKind l = lhs->type.kind;
    TypeKind r = rhs->type.kind;
    if (l == TYPE_ERROR || r == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    switch (node->as.binary.op) {
    case OP_ADD:
        /* String concatenation is handled here; numeric '+' shares the
         * arithmetic path below. */
        if (l == TYPE_STR || r == TYPE_STR) {
            if (l == TYPE_STR && r == TYPE_STR) {
                node->type.kind = TYPE_STR;
            } else {
                error_add(errors, "<sema>", node->loc.line, node->loc.col,
                          "字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'",
                          type_name(l), type_name(r));
                node->type.kind = TYPE_ERROR;
            }
            break;
        }
        /* fallthrough */
    case OP_SUB:
    case OP_MUL:
    case OP_DIV:
    case OP_MOD:
        if (is_numeric(l) && is_numeric(r)) {
            node->type.kind = promote(l, r);
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col, "算术运算需要数值类型");
            node->type.kind = TYPE_ERROR;
        }
        break;

    case OP_EQ:
    case OP_NE:
    case OP_LT:
    case OP_GT:
    case OP_LE:
    case OP_GE:
        if (is_comparable(l, r)) {
            node->type.kind = TYPE_BOOL;
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "类型 '%s' 和 '%s' 无法比较", type_name(l), type_name(r));
            node->type.kind = TYPE_ERROR;
        }
        break;

    case OP_AND:
    case OP_OR:
        if (l == TYPE_BOOL && r == TYPE_BOOL) {
            node->type.kind = TYPE_BOOL;
        } else {
            error_add(errors, "<sema>", node->loc.line, node->loc.col, "逻辑运算需要布尔类型");
            node->type.kind = TYPE_ERROR;
        }
        break;

    default:
        break;
    }
}
/*
 * Argument type check shared by CALL_EXPR and METHOD_CALL.
 *
 * Compares the already-analysed argument `arg` against the parameter type
 * (`expected`, plus `expected_sname` for struct parameters) and reports a
 * mismatch at `call_node`'s location; `idx` is the zero-based argument index
 * and is printed one-based.
 *
 * Returns true only when the parameter's struct name is one of `fn_sym`'s
 * generic type parameters (any argument type is accepted then). Every other
 * path returns false, whether or not an error was reported.
 */
static bool check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname,
                           size_t idx, AstNode* call_node, Symbol* fn_sym,
                           ErrorList* errors, Arena* a) {
    (void)a;
    const TypeKind actual = arg->type.kind;

    /* The argument's own error was already reported. */
    if (actual == TYPE_ERROR) {
        return false;
    }

    if (expected != TYPE_STRUCT || expected_sname == NULL) {
        /* Non-struct parameter: exact match, enum passed as i64, an implicit
         * conversion, or an i64 literal narrowed to i32/u64/char. */
        bool literal_narrowing = actual == TYPE_I64
            && arg->kind == AST_LITERAL_EXPR
            && (expected == TYPE_I32 || expected == TYPE_U64 || expected == TYPE_CHAR);
        bool acceptable = actual == expected
            || (expected == TYPE_I64 && actual == TYPE_ENUM)
            || can_implicit_convert(actual, expected)
            || literal_narrowing;
        if (!acceptable) {
            error_add(errors, "<sema>", call_node->loc.line, call_node->loc.col,
                      "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
                      idx + 1, type_name(expected), type_name(actual));
        }
        return false;
    }

    /* A struct-typed parameter whose name is a generic type parameter
     * accepts any argument type. */
    if (fn_sym != NULL && fn_sym->type_params != NULL) {
        size_t t = 0;
        while (t < fn_sym->type_param_count) {
            if (strcmp(expected_sname, fn_sym->type_params[t]) == 0) {
                return true;
            }
            t++;
        }
    }

    /* Concrete struct parameter: kind and struct name must both match. */
    const char* actual_name = arg->type.struct_name;
    bool same_struct = actual == TYPE_STRUCT
        && actual_name != NULL
        && strcmp(actual_name, expected_sname) == 0;
    if (!same_struct) {
        error_add(errors, "<sema>", call_node->loc.line, call_node->loc.col,
                  "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
                  idx + 1, expected_sname ? expected_sname : "struct",
                  actual_name ? actual_name : type_name(actual));
    }
    return false;
}
/*
 * Named-argument reordering shared by CALL_EXPR and METHOD_CALL.
 *
 * Rewrites the call's argument array in place so that every argument sits at
 * the position of the parameter it binds to. Positional arguments keep their
 * index; a named argument moves to the index of the parameter with that name,
 * searching `sym`'s parameters from `param_offset` onward (1 for methods, to
 * skip the implicit receiver).
 *
 * Returns true on success or when no argument is actually named. Returns
 * false after reporting an error when:
 *   - a name matches no parameter,
 *   - two arguments bind to the same parameter position (previously this left
 *     a NULL slot that was copied back into the argument list), or
 *   - the call has more arguments than the fixed scratch buffer holds
 *     (previously this overflowed the buffer).
 *
 * NOTE(review): the argument list is read through `node->as.call` even when
 * the caller passes a method-call node — confirm that the `call` and
 * `method_call` variants share the layout of these fields.
 */
static bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset,
                               ErrorList* errors, const char* call_name) {
    enum { MAX_REORDER_ARGS = 16 };

    AstNode** args = node->as.call.args;
    const char** arg_names = node->as.call.arg_names;
    size_t arg_count = node->as.call.arg_count;
    if (!arg_names) return true;

    // A purely positional call needs no reordering (and is not subject to
    // the scratch-buffer limit).
    bool any_named = false;
    for (size_t i = 0; i < arg_count; i++) {
        if (arg_names[i]) { any_named = true; break; }
    }
    if (!any_named) return true;

    if (arg_count > MAX_REORDER_ARGS) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "'%s' 的参数过多: 命名参数调用最多支持 %d 个参数",
                  call_name, (int)MAX_REORDER_ARGS);
        return false;
    }

    AstNode* reordered[MAX_REORDER_ARGS] = {0};
    for (size_t i = 0; i < arg_count; i++) {
        size_t slot = i; // positional arguments stay where they are
        if (arg_names[i]) {
            bool found = false;
            for (size_t j = (size_t)param_offset; j < sym->param_count; j++) {
                if (sym->param_names && sym->param_names[j] &&
                    strcmp(arg_names[i], sym->param_names[j]) == 0) {
                    slot = j - (size_t)param_offset;
                    found = true;
                    break;
                }
            }
            if (!found) {
                error_add(errors, "<sema>", node->loc.line, node->loc.col,
                          "'%s' 没有名为 '%s' 的参数", call_name, arg_names[i]);
                return false;
            }
        }
        // Every slot may be filled once. Since there are exactly arg_count
        // arguments and arg_count slots, rejecting collisions also guarantees
        // that no slot is left empty.
        if (slot >= arg_count || reordered[slot] != NULL) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "'%s' 的第 %zu 个参数被重复指定", call_name, slot + 1);
            return false;
        }
        reordered[slot] = args[i];
    }
    memcpy(args, reordered, arg_count * sizeof(AstNode*));
    return true;
}
/*
 * Type-check a free-function call.
 *
 * Steps: resolve the callee, check the argument count, reorder named
 * arguments, then analyse and type-check every argument. The node's type is
 * the callee's declared return type, except for generic functions whose
 * return type is a type parameter: then the node takes the type of the first
 * argument passed for a parameter declared with that same type parameter.
 *
 * Differences from the previous version:
 *   - all arguments are analysed even after the generic return type has been
 *     determined (the early return used to skip the remaining ones);
 *   - the return type is bound only from a parameter whose type parameter is
 *     the return type's, not from whichever generic parameter comes first.
 */
static void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    Symbol* sym = scope_lookup(scope, node->as.call.name);
    if (!sym || sym->kind != SYM_FUNCTION) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的函数 '%s'", node->as.call.name);
        node->type.kind = TYPE_ERROR;
        // Still analyse the arguments so their own errors are reported.
        for (size_t i = 0; i < node->as.call.arg_count; i++)
            analyze_expr(node->as.call.args[i], scope, errors, a);
        return;
    }
    if (node->as.call.arg_count != sym->param_count) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "函数 '%s' 需要 %zu 个参数,但提供了 %zu 个",
                  node->as.call.name, sym->param_count, node->as.call.arg_count);
        node->type.kind = TYPE_ERROR;
        for (size_t i = 0; i < node->as.call.arg_count; i++)
            analyze_expr(node->as.call.args[i], scope, errors, a);
        return;
    }
    if (!reorder_named_args(node, sym, 0, errors, node->as.call.name)) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    // Concrete type bound to the generic return type, if any.
    bool generic_bound = false;
    TypeKind bound_kind = TYPE_ERROR;
    const char* bound_struct_name = NULL;

    for (size_t i = 0; i < node->as.call.arg_count; i++) {
        AstNode* arg = node->as.call.args[i];
        const char* param_sname =
            sym->param_struct_names ? sym->param_struct_names[i] : NULL;

        analyze_expr(arg, scope, errors, a);
        bool is_generic_param = check_arg_type(arg, sym->param_types[i], param_sname,
                                               i, node, sym, errors, a);

        // Generics: when this parameter is declared with the same type
        // parameter as the return type, the argument's concrete type becomes
        // the call's type. The first such argument wins.
        if (!generic_bound && is_generic_param
            && sym->return_type == TYPE_STRUCT
            && sym->return_struct_type_name && param_sname
            && strcmp(param_sname, sym->return_struct_type_name) == 0) {
            generic_bound = true;
            bound_kind = arg->type.kind;
            bound_struct_name = arg->type.struct_name;
        }
    }

    if (generic_bound) {
        node->type.kind = bound_kind;
        node->type.struct_name = bound_struct_name;
        return;
    }
    node->type.kind = sym->return_type;
    if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name)
        node->type.struct_name = sym->return_struct_type_name;
}
/*
 * Type-check `object.field`.
 *
 * The object must be a struct with a known, defined struct type that has a
 * field of the requested name. On success the node receives the field's type
 * (and struct name, for struct-typed fields) and the resolved field index.
 * Any failure types the node TYPE_ERROR; an object that is already
 * TYPE_ERROR is propagated without a new diagnostic.
 */
static void analyze_field_access(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* target = node->as.field_access.object;
    const char* owner_name = NULL;
    Symbol* owner = NULL;
    int index = -1;

    analyze_expr(target, scope, errors, a);

    if (target->type.kind == TYPE_ERROR) {
        goto fail;
    }
    if (target->type.kind != TYPE_STRUCT) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "类型 '%s' 不是结构体,不能访问字段 '%s'",
                  type_name(target->type.kind), node->as.field_access.field);
        goto fail;
    }

    owner_name = target->type.struct_name;
    if (owner_name == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col, "无法确定结构体类型");
        goto fail;
    }

    owner = scope_lookup_struct(scope, owner_name);
    if (owner == NULL) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的结构体 '%s'", owner_name);
        goto fail;
    }

    index = scope_struct_field_index(owner, node->as.field_access.field);
    if (index < 0) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 没有字段 '%s'", owner_name, node->as.field_access.field);
        goto fail;
    }

    node->type.kind = owner->struct_field_types[index];
    node->as.field_access.field_index = index;
    /* Struct-typed fields also carry the name of their struct type. */
    if (node->type.kind == TYPE_STRUCT
        && owner->struct_field_struct_names
        && owner->struct_field_struct_names[index]) {
        node->type.struct_name = owner->struct_field_struct_names[index];
    }
    return;

fail:
    node->type.kind = TYPE_ERROR;
}
/*
 * Type-check a struct initialiser `Name { field: value, ... }`.
 *
 * The type name is resolved as a struct, or through a type alias that names a
 * struct (the node's type_name is then rewritten to the aliased name). The
 * initialiser must supply exactly as many fields as the struct declares, each
 * naming an existing field, each at most once, with a value whose type kind
 * equals the field's declared kind.
 *
 * Duplicate field names are rejected: without that check a repeated name
 * passed the field-count test while another field was left uninitialised.
 *
 * On success the node is typed TYPE_STRUCT with the resolved struct name.
 */
static void analyze_struct_init(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    const char* resolved = node->as.struct_init.type_name;
    Symbol* struct_sym = scope_lookup_struct(scope, resolved);
    if (!struct_sym) {
        // Not a struct name: try resolving it as a type alias of a struct.
        Symbol* alias_sym = scope_lookup(scope, resolved);
        if (alias_sym && alias_sym->is_type_alias && alias_sym->struct_type_name) {
            resolved = alias_sym->struct_type_name;
            struct_sym = scope_lookup_struct(scope, resolved);
            node->as.struct_init.type_name = resolved;
        }
    }
    if (!struct_sym) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的结构体类型 '%s'", node->as.struct_init.type_name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (node->as.struct_init.field_count != struct_sym->struct_field_count) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 有 %zu 个字段,但提供了 %zu 个",
                  node->as.struct_init.type_name,
                  struct_sym->struct_field_count, node->as.struct_init.field_count);
        node->type.kind = TYPE_ERROR;
        return;
    }
    for (size_t i = 0; i < node->as.struct_init.field_count; i++) {
        const char* fname = node->as.struct_init.field_names[i];
        AstNode* fval = node->as.struct_init.field_values[i];
        analyze_expr(fval, scope, errors, a);

        int fi = scope_struct_field_index(struct_sym, fname);
        if (fi < 0) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "结构体 '%s' 没有字段 '%s'", node->as.struct_init.type_name, fname);
            node->type.kind = TYPE_ERROR;
            continue;
        }

        // Reject a field that an earlier initialiser entry already set.
        bool duplicate = false;
        for (size_t k = 0; k < i; k++) {
            const char* earlier = node->as.struct_init.field_names[k];
            if (earlier && fname && strcmp(earlier, fname) == 0) {
                duplicate = true;
                break;
            }
        }
        if (duplicate) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "字段 '%s' 被重复初始化", fname);
            node->type.kind = TYPE_ERROR;
            continue;
        }

        TypeKind expected = struct_sym->struct_field_types[fi];
        TypeKind actual = fval->type.kind;
        if (actual != TYPE_ERROR && actual != expected)
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'",
                      fname, type_name(expected), type_name(actual));
    }
    // Keep TYPE_ERROR if a field-level failure set it above.
    if (node->type.kind != TYPE_ERROR) {
        node->type.kind = TYPE_STRUCT;
        node->type.struct_name = resolved;
    }
}
/*
 * Type-check an enum variant expression `Enum::Variant` with an optional
 * payload.
 *
 * The enum and the variant must exist; the variant's index is stored on the
 * node. When a payload expression is present it is rejected if the enum
 * records payload types and this variant's is void; otherwise it is analysed
 * and its kind compared with the variant's payload kind. The node is typed
 * TYPE_ENUM unless one of the early failures typed it TYPE_ERROR.
 */
static void analyze_enum_variant(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    Symbol* decl = scope_lookup_struct(scope, node->as.enum_variant.enum_name);
    if (decl == NULL || decl->kind != SYM_ENUM) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "未定义的枚举 '%s'", node->as.enum_variant.enum_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    int index = scope_enum_variant_index(decl, node->as.enum_variant.variant_name);
    if (index < 0) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "枚举 '%s' 没有变体 '%s'",
                  node->as.enum_variant.enum_name, node->as.enum_variant.variant_name);
        node->type.kind = TYPE_ERROR;
        return;
    }
    node->as.enum_variant.variant_index = index;

    /* ADT: validate the payload. A missing payload-type table means void. */
    TypeKind wanted = decl->variant_payload_types
        ? decl->variant_payload_types[index]
        : TYPE_VOID;

    AstNode* payload = node->as.enum_variant.payload;
    if (payload != NULL) {
        if (decl->variant_payload_types && wanted == TYPE_VOID) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "枚举变体 '%s::%s' 不接受 payload",
                      node->as.enum_variant.enum_name, node->as.enum_variant.variant_name);
            node->type.kind = TYPE_ERROR;
            return;
        }
        analyze_expr(payload, scope, errors, a);
        TypeKind got = payload->type.kind;
        if (got != TYPE_ERROR && got != wanted) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "枚举变体 payload 类型不匹配: 期望 '%s',得到 '%s'",
                      type_name(wanted), type_name(got));
        }
    }
    node->type.kind = TYPE_ENUM;
}
/*
 * Type-check `array[index]`.
 *
 * Both sub-expressions are analysed first. The indexed expression must be an
 * array and the index must be i64; the node then takes the array's element
 * type and element struct name. An operand that is already TYPE_ERROR
 * propagates without a new diagnostic (the array operand is checked first).
 */
static void analyze_index_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    AstNode* base = node->as.index_expr.array;
    AstNode* subscript = node->as.index_expr.index;
    analyze_expr(base, scope, errors, a);
    analyze_expr(subscript, scope, errors, a);

    const TypeKind base_kind = base->type.kind;
    if (base_kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (base_kind != TYPE_ARRAY) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "类型 '%s' 不支持索引操作", type_name(base_kind));
        node->type.kind = TYPE_ERROR;
        return;
    }

    const TypeKind subscript_kind = subscript->type.kind;
    if (subscript_kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (subscript_kind != TYPE_I64) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "数组索引必须是 i64 类型, 得到 '%s'", type_name(subscript_kind));
        node->type.kind = TYPE_ERROR;
        return;
    }

    node->type.kind = base->type.element_type;
    node->type.struct_name = base->type.element_struct_name;
}
/*
 * Type-check a method call `receiver.method(args...)`.
 *
 * The receiver must be a struct with a known name. The method is looked up as
 * the function symbol "StructName$methodName", whose first parameter is the
 * receiver; the explicit arguments are therefore checked against parameters
 * 1..n. The node takes the method's declared return type.
 *
 * A receiver that already failed analysis propagates TYPE_ERROR without a
 * second diagnostic, matching analyze_field_access.
 */
static void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    analyze_expr(node->as.method_call.receiver, scope, errors, a);

    // Do not cascade: the receiver's own error has already been reported.
    if (node->as.method_call.receiver->type.kind == TYPE_ERROR) {
        node->type.kind = TYPE_ERROR;
        return;
    }

    const char* recv_struct = node->as.method_call.receiver->type.struct_name;
    if (node->as.method_call.receiver->type.kind != TYPE_STRUCT || !recv_struct) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "只有结构体类型支持方法调用");
        node->type.kind = TYPE_ERROR;
        return;
    }

    // Methods are registered as plain functions named "Struct$method".
    char mangled[256];
    snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct,
             node->as.method_call.method_name);
    Symbol* sym = scope_lookup(scope, mangled);
    if (!sym || sym->kind != SYM_FUNCTION) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "结构体 '%s' 没有方法 '%s'", recv_struct,
                  node->as.method_call.method_name);
        node->type.kind = TYPE_ERROR;
        return;
    }

    // +1 accounts for the receiver parameter.
    if (node->as.method_call.arg_count + 1 != sym->param_count) {
        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                  "方法 '%s' 需要 %zu 个参数,提供了 %zu 个",
                  node->as.method_call.method_name,
                  sym->param_count > 0 ? sym->param_count - 1 : 0,
                  node->as.method_call.arg_count);
        node->type.kind = TYPE_ERROR;
        return;
    }
    if (!reorder_named_args(node, sym, 1, errors, node->as.method_call.method_name)) {
        node->type.kind = TYPE_ERROR;
        return;
    }
    for (size_t i = 0; i < node->as.method_call.arg_count; i++) {
        analyze_expr(node->as.method_call.args[i], scope, errors, a);
        // Parameter i + 1: parameter 0 is the receiver.
        check_arg_type(node->as.method_call.args[i], sym->param_types[i + 1],
                       sym->param_struct_names ? sym->param_struct_names[i + 1] : NULL,
                       i, node, sym, errors, a);
    }
    node->type.kind = sym->return_type;
    if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name)
        node->type.struct_name = sym->return_struct_type_name;
}
// === Expression type checking (dispatcher) ===
/*
 * Route an expression node to the helper for its kind.
 *
 * `if` statements and blocks used in expression position are handed to
 * analyze_node. Literals and any kind not listed here are left untouched.
 */
static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    void (*handler)(AstNode*, Scope*, ErrorList*, Arena*) = NULL;

    switch (node->kind) {
    case AST_IDENT_EXPR:   handler = analyze_ident_expr;   break;
    case AST_UNARY_EXPR:   handler = analyze_unary_expr;   break;
    case AST_BINARY_EXPR:  handler = analyze_binary_expr;  break;
    case AST_CALL_EXPR:    handler = analyze_call_expr;    break;
    case AST_FIELD_ACCESS: handler = analyze_field_access; break;
    case AST_STRUCT_INIT:  handler = analyze_struct_init;  break;
    case AST_ENUM_VARIANT: handler = analyze_enum_variant; break;
    case AST_INDEX_EXPR:   handler = analyze_index_expr;   break;
    case AST_METHOD_CALL:  handler = analyze_method_call;  break;
    case AST_IF_STMT:
    case AST_BLOCK:
        handler = analyze_node;
        break;
    case AST_LITERAL_EXPR:
    default:
        break;
    }

    if (handler != NULL) {
        handler(node, scope, errors, a);
    }
}
/*
 * Analyse a program, function, block or statement node.
 *
 * AST_PROGRAM registers type aliases, enums and structs, rewrites impl-block
 * methods into functions named "Struct$method" appended to the program's
 * function list, registers every function signature, and finally analyses
 * every function body. The remaining cases type-check one statement kind
 * each; kinds not listed are forwarded to analyze_expr.
 *
 * A NULL node is ignored.
 *
 * Fix over the previous version: after impl rewriting, fn_count is set to the
 * number of functions actually written. Methods of an impl block whose target
 * struct is undefined are skipped, so the old `fn_count + extra_fn` made the
 * later passes read unwritten array slots.
 */
static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    if (!node) return;
    switch (node->kind) {
    case AST_PROGRAM:
        // Pass 0: register type aliases.
        for (size_t i = 0; i < node->as.program.alias_count; i++) {
            AstNode* alias = node->as.program.type_aliases[i];
            Symbol* sym = scope_insert(scope, a, alias->as.type_alias.name,
                                       SYM_VARIABLE, alias->as.type_alias.aliased_type);
            if (sym) {
                sym->is_type_alias = true;
                if (alias->as.type_alias.aliased_struct_name) {
                    sym->struct_type_name = alias->as.type_alias.aliased_struct_name;
                }
            }
        }
        // Register enums.
        for (size_t i = 0; i < node->as.program.enum_count; i++) {
            AstNode* ed = node->as.program.enums[i];
            scope_insert_enum(scope, a, ed->as.enum_decl.name,
                              ed->as.enum_decl.variants,
                              ed->as.enum_decl.variant_payload_types,
                              ed->as.enum_decl.variant_payload_struct_names,
                              ed->as.enum_decl.variant_count);
        }
        // Pass 1: collect all struct definitions.
        for (size_t i = 0; i < node->as.program.struct_count; i++) {
            AstNode* sd = node->as.program.structs[i];
            const char** fnames = (const char**)arena_alloc_impl(a,
                sd->as.struct_decl.field_count * sizeof(const char*));
            TypeKind* ftypes = (TypeKind*)arena_alloc_impl(a,
                sd->as.struct_decl.field_count * sizeof(TypeKind));
            const char** fstruct_names = (const char**)arena_alloc_impl(a,
                sd->as.struct_decl.field_count * sizeof(const char*));
            for (size_t j = 0; j < sd->as.struct_decl.field_count; j++) {
                fnames[j] = sd->as.struct_decl.fields[j]->as.parameter.name;
                ftypes[j] = sd->as.struct_decl.fields[j]->as.parameter.type;
                fstruct_names[j] = sd->as.struct_decl.fields[j]->as.parameter.struct_type_name;
            }
            scope_insert_struct(scope, a, sd->as.struct_decl.name,
                                fnames, ftypes, fstruct_names,
                                sd->as.struct_decl.field_count);
        }
        // Process impl blocks: rename each method to StructName$methodName and
        // append it to the program's functions array, so that it is registered
        // and analysed like an ordinary function and is available to later
        // stages.
        {
            // First count how many functions the impl blocks may add.
            size_t extra_fn = 0;
            for (size_t i = 0; i < node->as.program.impl_count; i++) {
                AstNode* impl = node->as.program.impls[i];
                extra_fn += impl->as.impl_block.method_count;
            }
            if (extra_fn > 0) {
                AstNode** new_fns = (AstNode**)arena_alloc_impl(a,
                    (node->as.program.fn_count + extra_fn) * sizeof(AstNode*));
                // Skip the copy when there is nothing to copy; the source
                // array may be NULL in that case.
                if (node->as.program.fn_count > 0) {
                    memcpy(new_fns, node->as.program.functions,
                           node->as.program.fn_count * sizeof(AstNode*));
                }
                size_t write_pos = node->as.program.fn_count;
                for (size_t i = 0; i < node->as.program.impl_count; i++) {
                    AstNode* impl = node->as.program.impls[i];
                    const char* st_name = impl->as.impl_block.struct_name;
                    // The target struct must exist.
                    Symbol* st_sym = scope_lookup_struct(scope, st_name);
                    if (!st_sym) {
                        error_add(errors, "<sema>", impl->loc.line, impl->loc.col,
                                  "impl 的目标结构体 '%s' 未定义", st_name);
                        continue;
                    }
                    for (size_t j = 0; j < impl->as.impl_block.method_count; j++) {
                        AstNode* method = impl->as.impl_block.methods[j];
                        // Build the mangled function name.
                        char mangled[256];
                        snprintf(mangled, sizeof(mangled), "%s$%s", st_name,
                                 method->as.function.name);
                        method->as.function.name = arena_strdup_impl(a, mangled,
                                                                     strlen(mangled));
                        // Append to the new functions array.
                        new_fns[write_pos++] = method;
                    }
                }
                // Update the program node. Only the first write_pos slots are
                // valid: methods of impl blocks with an undefined target were
                // skipped above.
                node->as.program.functions = new_fns;
                node->as.program.fn_count = write_pos;
            }
        }
        // Pass 2: collect all function signatures.
        for (size_t i = 0; i < node->as.program.fn_count; i++) {
            AstNode* fn = node->as.program.functions[i];
            TypeKind* pts = (TypeKind*)arena_alloc_impl(a, fn->as.function.param_count * sizeof(TypeKind));
            const char** pnames = (const char**)arena_alloc_impl(a, fn->as.function.param_count * sizeof(const char*));
            const char** pstruct_names = (const char**)arena_alloc_impl(a, fn->as.function.param_count * sizeof(const char*));
            for (size_t j = 0; j < fn->as.function.param_count; j++) {
                TypeKind pt = fn->as.function.params[j]->as.parameter.type;
                const char* psn = fn->as.function.params[j]->as.parameter.struct_type_name;
                const char* pn = fn->as.function.params[j]->as.parameter.name;
                // Resolve a type alias used as the parameter type.
                if (psn) {
                    Symbol* as = scope_lookup(scope, psn);
                    if (as && as->is_type_alias) {
                        pt = as->type;
                        psn = as->struct_type_name;
                    }
                }
                pts[j] = pt;
                pnames[j] = pn;
                pstruct_names[j] = psn;
            }
            // Resolve a type alias used as the return type.
            TypeKind ret_t = fn->as.function.return_type;
            const char* ret_sn = fn->as.function.return_struct_type_name;
            if (ret_sn) {
                Symbol* as = scope_lookup(scope, ret_sn);
                if (as && as->is_type_alias) {
                    ret_t = as->type;
                    ret_sn = as->struct_type_name;
                }
            }
            scope_insert_function(scope, a, fn->as.function.name,
                                  ret_t, ret_sn,
                                  pts, pnames, pstruct_names,
                                  fn->as.function.param_count,
                                  fn->as.function.type_params,
                                  fn->as.function.type_param_count);
        }
        // Pass 3: analyse every function body.
        for (size_t i = 0; i < node->as.program.fn_count; i++) {
            analyze_node(node->as.program.functions[i], scope, errors, a);
        }
        break;
    case AST_FUNCTION: {
        Scope* fn_scope = scope_new(a, scope);
        // Register parameters (resolving type aliases and updating the AST
        // node so later stages see the resolved type).
        for (size_t i = 0; i < node->as.function.param_count; i++) {
            AstNode* p = node->as.function.params[i];
            TypeKind pt = p->as.parameter.type;
            const char* psn = p->as.parameter.struct_type_name;
            if (psn) {
                Symbol* as = scope_lookup(scope, psn);
                if (as && as->is_type_alias) {
                    pt = as->type;
                    psn = as->struct_type_name;
                    // Update the AST node.
                    p->as.parameter.type = pt;
                    p->as.parameter.struct_type_name = psn;
                }
            }
            Symbol* sym = scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, pt);
            if (sym && pt == TYPE_STRUCT && psn) {
                sym->struct_type_name = psn;
            }
        }
        // Resolve a type alias used as the return type.
        const char* ret_sn = node->as.function.return_struct_type_name;
        if (ret_sn) {
            Symbol* as = scope_lookup(scope, ret_sn);
            if (as && as->is_type_alias) {
                node->as.function.return_type = as->type;
                node->as.function.return_struct_type_name = as->struct_type_name;
            }
        }
        // Expose this function's return type to AST_RETURN_STMT checks while
        // the body is analysed, then restore the previous values.
        TypeKind saved = current_return_type;
        const char* saved_name = current_return_struct_name;
        current_return_type = node->as.function.return_type;
        current_return_struct_name = node->as.function.return_struct_type_name;
        analyze_node(node->as.function.body, fn_scope, errors, a);
        current_return_type = saved;
        current_return_struct_name = saved_name;
        break;
    }
    case AST_BLOCK:
        for (size_t i = 0; i < node->as.block.stmt_count; i++) {
            analyze_node(node->as.block.stmts[i], scope, errors, a);
        }
        // Block as a value: its type is that of the last statement when that
        // statement is an expression statement or a non-void `if`.
        if (node->as.block.stmt_count > 0) {
            AstNode* last = node->as.block.stmts[node->as.block.stmt_count - 1];
            TypeKind ek = TYPE_VOID;
            const char* esn = NULL;
            if (last->kind == AST_EXPR_STMT) {
                ek = last->as.expr_stmt.expr->type.kind;
                esn = last->as.expr_stmt.expr->type.struct_name;
            } else if (last->kind == AST_IF_STMT && last->type.kind != TYPE_VOID) {
                ek = last->type.kind;
                esn = last->type.struct_name;
            }
            if (ek != TYPE_ERROR && ek != TYPE_VOID) {
                node->type.kind = ek;
                node->type.struct_name = esn;
            }
        }
        break;
    case AST_LET_STMT: {
        TypeKind var_type;
        const char* var_struct_name = NULL;
        bool is_array_type = false;
        if (node->as.let_stmt.has_type_annot) {
            if (node->as.let_stmt.annot_type == TYPE_ARRAY) {
                // Array annotation: skip analysing init (init is a
                // self-referential placeholder).
                is_array_type = true;
                var_type = TYPE_ARRAY;
            } else {
                analyze_expr(node->as.let_stmt.init, scope, errors, a);
                TypeKind inferred = node->as.let_stmt.init->type.kind;
                const char* annot_struct = node->as.let_stmt.struct_type_name;
                if (annot_struct) {
                    // Check for a type alias first.
                    Symbol* alias_sym = scope_lookup(scope, annot_struct);
                    if (alias_sym && alias_sym->is_type_alias) {
                        var_type = alias_sym->type;
                        var_struct_name = alias_sym->struct_type_name;
                    } else {
                        // Struct type annotation.
                        Symbol* st_sym = scope_lookup_struct(scope, annot_struct);
                        if (!st_sym) {
                            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                                      "未定义的类型 '%s'", annot_struct);
                            break;
                        }
                        var_type = TYPE_STRUCT;
                        var_struct_name = annot_struct;
                    }
                } else {
                    var_type = node->as.let_stmt.annot_type;
                }
                // An i64 literal may initialise an i32 / u64 / char variable.
                bool literal_to_int = (inferred == TYPE_I64
                    && node->as.let_stmt.init->kind == AST_LITERAL_EXPR
                    && (var_type == TYPE_I32 || var_type == TYPE_U64 || var_type == TYPE_CHAR));
                if (inferred != TYPE_ERROR && inferred != var_type
                    && !can_implicit_convert(inferred, var_type)
                    && !literal_to_int) {
                    error_add(errors, "<sema>", node->loc.line, node->loc.col,
                              "变量 '%s' 类型标注为 '%s',但初始化表达式类型为 '%s'",
                              node->as.let_stmt.name,
                              annot_struct ? annot_struct : type_name(var_type),
                              type_name(inferred));
                }
            }
        } else {
            analyze_expr(node->as.let_stmt.init, scope, errors, a);
            TypeKind inferred = node->as.let_stmt.init->type.kind;
            // Type inference.
            if (inferred == TYPE_ERROR || inferred == TYPE_VOID) {
                error_add(errors, "<sema>", node->loc.line, node->loc.col,
                          "无法从表达式推断变量 '%s' 的类型", node->as.let_stmt.name);
                break;
            }
            var_type = inferred;
            if (inferred == TYPE_STRUCT) {
                var_struct_name = node->as.let_stmt.init->type.struct_name;
            }
        }
        node->type.kind = var_type;
        node->type.struct_name = var_struct_name;
        if (is_array_type) {
            node->type.element_type = node->as.let_stmt.annot_element_type;
            node->type.element_struct_name = node->as.let_stmt.annot_element_struct_name;
            node->type.array_size = node->as.let_stmt.annot_array_size;
        }
        // A NULL result from scope_insert is reported as a redefinition.
        Symbol* sym = scope_insert(scope, a, node->as.let_stmt.name, SYM_VARIABLE, var_type);
        if (!sym) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "变量 '%s' 重复定义", node->as.let_stmt.name);
        } else {
            sym->is_mut = node->as.let_stmt.is_mut;
            if (var_struct_name) {
                sym->type = TYPE_STRUCT;
                sym->struct_type_name = var_struct_name;
            }
            if (is_array_type) {
                sym->array_element_type = node->type.element_type;
                sym->array_element_struct_name = node->type.element_struct_name;
                sym->array_size = node->type.array_size;
            }
        }
        break;
    }
    case AST_ASSIGN_STMT: {
        Symbol* sym = scope_lookup(scope, node->as.assign_stmt.name);
        if (!sym) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "未定义的变量 '%s'", node->as.assign_stmt.name);
            node->type.kind = TYPE_ERROR;
            break;
        }
        if (sym->kind != SYM_VARIABLE) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "'%s' 不是变量,不能赋值", node->as.assign_stmt.name);
            node->type.kind = TYPE_ERROR;
            break;
        }
        if (!sym->is_mut) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "不能对不可变变量 '%s' 赋值(需用 var 声明)",
                      node->as.assign_stmt.name);
            node->type.kind = TYPE_ERROR;
            break;
        }
        analyze_expr(node->as.assign_stmt.value, scope, errors, a);
        TypeKind value_ty = node->as.assign_stmt.value->type.kind;
        if (value_ty != TYPE_ERROR && value_ty != sym->type) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "赋值类型不匹配: 变量 '%s' 类型为 '%s',但表达式类型为 '%s'",
                      node->as.assign_stmt.name, type_name(sym->type), type_name(value_ty));
        }
        node->type.kind = TYPE_VOID;
        break;
    }
    case AST_ARRAY_ASSIGN_STMT: {
        Symbol* sym = scope_lookup(scope, node->as.array_assign.name);
        if (!sym) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "未定义的变量 '%s'", node->as.array_assign.name);
            node->type.kind = TYPE_ERROR; break;
        }
        if (sym->type != TYPE_ARRAY) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "'%s' 不是数组类型,不能使用索引赋值", node->as.array_assign.name);
            node->type.kind = TYPE_ERROR; break;
        }
        analyze_expr(node->as.array_assign.index, scope, errors, a);
        analyze_expr(node->as.array_assign.value, scope, errors, a);
        AstNode* idx = node->as.array_assign.index;
        AstNode* val = node->as.array_assign.value;
        if (idx->type.kind != TYPE_ERROR && idx->type.kind != TYPE_I64) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "数组索引必须是 i64 类型, 得到 '%s'", type_name(idx->type.kind));
        }
        TypeKind elem_kind = (sym->type == TYPE_ARRAY) ? sym->array_element_type : sym->type;
        if (val->type.kind != TYPE_ERROR && val->type.kind != elem_kind) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "数组元素类型不匹配: 期望 '%s',得到 '%s'",
                      type_name(elem_kind), type_name(val->type.kind));
        }
        // For struct elements the struct names must match as well.
        if (val->type.kind != TYPE_ERROR && elem_kind == TYPE_STRUCT
            && sym->array_element_struct_name && val->type.struct_name) {
            if (strcmp(sym->array_element_struct_name, val->type.struct_name) != 0) {
                error_add(errors, "<sema>", node->loc.line, node->loc.col,
                          "数组元素类型不匹配: 期望 '%s',得到 '%s'",
                          sym->array_element_struct_name, val->type.struct_name);
            }
        }
        node->type.kind = TYPE_VOID;
        break;
    }
    case AST_IF_STMT:
        analyze_expr(node->as.if_stmt.cond, scope, errors, a);
        if (node->as.if_stmt.cond->type.kind != TYPE_BOOL &&
            node->as.if_stmt.cond->type.kind != TYPE_ERROR) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col, "if 条件必须是布尔类型");
        }
        analyze_node(node->as.if_stmt.then_block, scope, errors, a);
        if (node->as.if_stmt.else_block) {
            analyze_node(node->as.if_stmt.else_block, scope, errors, a);
        }
        // `if` as a value: its type is the common non-void type of both
        // branches; it stays unset when there is no else or the types differ.
        {
            AstNode* tb = node->as.if_stmt.then_block;
            AstNode* eb = node->as.if_stmt.else_block;
            if (tb && eb) {
                TypeKind tt = tb->type.kind, et = eb->type.kind;
                if (tt == et && tt != TYPE_VOID && tt != TYPE_ERROR) {
                    node->type.kind = tt;
                    if (tt == TYPE_STRUCT && tb->type.struct_name)
                        node->type.struct_name = tb->type.struct_name;
                }
            }
        }
        break;
    case AST_WHILE_STMT:
        analyze_expr(node->as.while_stmt.cond, scope, errors, a);
        if (node->as.while_stmt.cond->type.kind != TYPE_BOOL &&
            node->as.while_stmt.cond->type.kind != TYPE_ERROR) {
            error_add(errors, "<sema>", node->loc.line, node->loc.col, "while 条件必须是布尔类型");
        }
        analyze_node(node->as.while_stmt.body, scope, errors, a);
        break;
    case AST_RETURN_STMT:
        if (node->as.return_stmt.expr) {
            analyze_expr(node->as.return_stmt.expr, scope, errors, a);
            node->type.kind = node->as.return_stmt.expr->type.kind;
            TypeKind actual = node->as.return_stmt.expr->type.kind;
            TypeKind expected = current_return_type;
            if (actual != TYPE_ERROR && expected != TYPE_VOID) {
                if (expected == TYPE_STRUCT) {
                    // Struct return type: compare the concrete type names.
                    const char* actual_name = node->as.return_stmt.expr->type.struct_name;
                    const char* expected_name = current_return_struct_name;
                    if (actual != TYPE_STRUCT || !actual_name || !expected_name ||
                        strcmp(actual_name, expected_name) != 0) {
                        error_add(errors, "<sema>", node->loc.line, node->loc.col,
                                  "返回类型不匹配: 期望 '%s',得到 '%s'",
                                  expected_name ? expected_name : "struct",
                                  actual_name ? actual_name : type_name(actual));
                    }
                } else if (actual != expected &&
                           !(expected == TYPE_I64 && actual == TYPE_ENUM)) {
                    error_add(errors, "<sema>", node->loc.line, node->loc.col,
                              "返回类型不匹配: 期望 '%s',得到 '%s'",
                              type_name(expected), type_name(actual));
                }
            }
        } else if (current_return_type != TYPE_VOID) {
            // Bare `return` inside a function that declares a return value.
            error_add(errors, "<sema>", node->loc.line, node->loc.col,
                      "函数应返回值类型 '%s'",
                      current_return_struct_name ? current_return_struct_name : type_name(current_return_type));
        }
        break;
    case AST_EXPR_STMT:
        analyze_expr(node->as.expr_stmt.expr, scope, errors, a);
        break;
    default:
        analyze_expr(node, scope, errors, a);
        break;
    }
}
/*
 * Entry point of semantic analysis.
 *
 * Creates the global scope, registers the built-in one-argument functions
 * print_i64, print_f64, print_bool and print_str (all returning void), and
 * then analyses `ast` in that scope, appending diagnostics to `errors`.
 * `arena` is passed to every scope and symbol helper called here.
 */
void sema_analyze(AstNode* ast, ErrorList* errors, Arena* arena) {
    Scope* global_scope = scope_new(arena, NULL);

    /* Built-ins: name k takes a single parameter of type builtin_params[k].
     * Each call receives a pointer to its own array element. */
    const char* builtin_names[] = {
        "print_i64", "print_f64", "print_bool", "print_str"
    };
    TypeKind builtin_params[] = {
        TYPE_I64, TYPE_F64, TYPE_BOOL, TYPE_STR
    };
    const size_t builtin_count = sizeof builtin_names / sizeof builtin_names[0];

    for (size_t k = 0; k < builtin_count; k++) {
        scope_insert_function(global_scope, arena, builtin_names[k],
                              TYPE_VOID, NULL,
                              &builtin_params[k], NULL, NULL, 1,
                              NULL, 0);
    }

    analyze_node(ast, global_scope, errors, arena);
}